以前の記事でソフトマックス回帰、またMLP(多層パーセプトロン)でMNISTの0から9までの手書き数字の画像データセットを分類する記事を書きました。
[機械学習] ロジスティック回帰(およびソフトマックス回帰)でMNISTの分類
[機械学習/深層学習] 多層パーセプトロンを実装してMNISTの分類
今回はCNN(畳み込みニューラルネットワーク)を実装し、同様にMNISTの学習、分類をした上で、最後に以前の2つとどんな違いがあったか確認したいと思います。
前提
- 今回はPytorchを使用
- 実行環境はGoogle Colab。ランタイムはPython3(T4 GPU)を使用
※ 参照:機械学習・深層学習を勉強する際の検証用環境について - 本記事のコード全容はこちらからダウンロード可能。ipynbファイルであり、そのまま自身のGoogle Driveにアップロードして実行可能
- 数学的知識や用語の説明について、参考文献やリンクを最下部に掲載 (本記事内で詳細には解説しませんが、流れや実施内容がわかるようにしたいと思います)
全体の流れ
大きく分けると 7ステップ になります。
- データ前処理・読み込み
- Dataset / DataLoader の準備
- CNNモデルの定義
- 順伝播(forward)
- 損失関数・最適化手法
- 学習ループ
- 結果の可視化(正解・不正解)
実装
1. データ前処理・読み込み
MNIST データセットに対してテンソル変換および正規化を行い、ニューラルネットワークで学習可能な入力表現へ変換します。
- ToTensor()
- 画像を (H, W) → (C, H, W) に変換
- 値を [0, 255] → [0, 1] に正規化
- Normalize((0.5,), (0.5,))
- [0,1] → [-1,1] にスケーリング
- 学習を安定させる(勾配が暴れにくい)
※ 「正規化」で入力のスケールを揃える
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
transform = transforms.Compose([
transforms.ToTensor(), # [0,255] → [0,1]
transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST(
root="./data", train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST(
root="./data", train=False, download=True, transform=transform
)
2. Dataset / DataLoader の準備
訓練用データと評価用データを分離し、ミニバッチ学習が可能な形でデータローダを構築します。
- Dataset:
- 画像+ラベルの集合体
- DataLoader:
- ミニバッチ化
- シャッフル
- GPU転送しやすくする
※ 「for x, t in train_loader」だけで学習が回せる状態を作っている
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
3. CNNモデルの定義
畳み込み層・プーリング層・全結合層からなる畳み込みニューラルネットワークを定義します。
- 畳み込み層
- 局所特徴(線・角・丸み)を抽出
- 重み共有 → 位置ずれに強い
- プーリング層
- 空間サイズを半分に
- 細かい位置情報を捨てる
- 歪み・ズレ耐性を獲得
- 全結合層
- 抽出した特徴を使って最終判断
- クラス分類器の役割
class CNN(nn.Module):
def __init__(self):
super().__init__()
# 畳み込み層
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
# プーリング層
self.pool = nn.MaxPool2d(2, 2)
# 全結合層
self.fc1 = nn.Linear(64 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 10)
4. 順伝播(forward)
入力画像から出力スコア(ロジット)を計算する一連の処理を定義します。
流れ
画像
↓
畳み込み + ReLU
↓
プーリング
↓
畳み込み + ReLU
↓
プーリング
↓
Flatten
↓
全結合
↓
クラススコア
def forward(self, x):
x = torch.relu(self.conv1(x)) # (B, 32, 28, 28)
x = self.pool(x) # (B, 32, 14, 14)
x = torch.relu(self.conv2(x)) # (B, 64, 14, 14)
x = self.pool(x) # (B, 64, 7, 7)
x = x.view(x.size(0), -1) # flatten
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
5. 損失関数・最適化手法
モデルの予測と正解ラベルの誤差を定量化し、パラメータ更新方法を定義します。
- CrossEntropyLoss
- softmax + log + NLL をまとめたもの
- Adam
- 学習率を自動調整
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
6. 学習ループ
誤差逆伝播法によりネットワークの重みを反復的に更新します。
- 各epochでやっていること
- 順伝播
- 損失計算
- 逆伝播(勾配計算)
- パラメータ更新
for epoch in range(10):
model.train()
for x, t in train_loader:
x, t = x.to(device), t.to(device)
optimizer.zero_grad()
y = model(x)
loss = criterion(y, t)
loss.backward()
optimizer.step()
# 検証
model.eval()
correct = 0
total = 0
with torch.no_grad():
for x, t in test_loader:
x, t = x.to(device), t.to(device)
y = model(x)
pred = y.argmax(dim=1)
correct += (pred == t).sum().item()
total += t.size(0)
acc = correct / total
print(f"Epoch {epoch+1}: Test Accuracy = {acc:.4f}")
7. 結果の可視化(正解・不正解)
テストデータを用いて汎化性能を評価し、正解・不正解例を可視化してモデルの振る舞いを分析します。
import matplotlib.pyplot as plt
import numpy as np
def collect_correct_incorrect(model, dataloader, device, max_samples=20):
model.eval()
correct = []
incorrect = []
with torch.no_grad():
for x, t in dataloader:
x, t = x.to(device), t.to(device)
y = model(x)
pred = y.argmax(dim=1)
for i in range(x.size(0)):
img = x[i].cpu().squeeze().numpy()
true = t[i].item()
p = pred[i].item()
if true == p and len(correct) < max_samples:
correct.append((img, true, p))
elif true != p and len(incorrect) < max_samples:
incorrect.append((img, true, p))
if len(correct) >= max_samples and len(incorrect) >= max_samples:
return correct, incorrect
return correct, incorrect
def show_correct_incorrect(correct, incorrect, n=10):
fig = plt.figure(figsize=(12, 4))
# 正解例
for i, (img, t, p) in enumerate(correct[:n]):
ax = fig.add_subplot(2, n, i + 1)
ax.imshow(img, cmap="gray")
ax.set_title(f"✓ T:{t} P:{p}", fontsize=9)
ax.axis("off")
# 不正解例
for i, (img, t, p) in enumerate(incorrect[:n]):
ax = fig.add_subplot(2, n, n + i + 1)
ax.imshow(img, cmap="gray")
ax.set_title(f"✗ T:{t} P:{p}", fontsize=9)
ax.axis("off")
plt.suptitle("MNIST CNN Classification Results")
plt.tight_layout()
plt.show()
correct, incorrect = collect_correct_incorrect(
model,
test_loader,
device,
max_samples=20
)
show_correct_incorrect(correct, incorrect, n=10)
最後に
分類結果の画像を見ると人間でも判別がなかなか難しいものは誤りがあるが、例えば傾いた数字の9など分類できています。
これは非常に良いモデル挙動と言えるかと思います。入力をX次元のベクトルと見るMLPに比べると、正確に判別できているようです。
MLPとCNNの決定的な違いを表にまとめます。
| 観点 | MLP | CNN |
|---|---|---|
| 入力 | ベクトル | 画像 |
| 空間構造 | 失う | 保つ |
| 崩れ耐性 | 弱い | 強い |
| 特徴 | 人任せ | 自動抽出 |
参考文献、リンク
- ゼロからつくるPython機械学習プログラミング入門
-
詳解ディープラーニング第2版
※ 詳解とありますが、入門的な内容から丁寧に解説してあります。 -
YouTubeチャンネル - 予備校のノリで学ぶ「大学の数学・物理」
※ 数学的知識の学習としては、世界一わかりやすかったです。

