0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

[機械学習/深層学習] 畳み込みニューラルネットワークを実装してMNISTの分類

Last updated at Posted at 2026-01-05

 以前の記事でソフトマックス回帰、またMLP(多層パーセプトロン)でMNISTの0から9までの手書き数字の画像データセットを分類する記事を書きました。
[機械学習] ロジスティック回帰(およびソフトマックス回帰)でMNISTの分類
[機械学習/深層学習] 多層パーセプトロンを実装してMNISTの分類

今回はCNN(畳み込みニューラルネットワーク)を実装し、同様にMNISTの学習、分類をした上で、最後に以前の2つとどんな違いがあったか確認したいと思います。

前提

  • 今回はPytorchを使用
  • 実行環境はGoogle Colab。ランタイムはPython3(T4 GPU)を使用
     ※ 参照:機械学習・深層学習を勉強する際の検証用環境について
  • 本記事のコード全容はこちらからダウンロード可能。ipynbファイルであり、そのまま自身のGoogle Driveにアップロードして実行可能
  • 数学的知識や用語の説明について、参考文献やリンクを最下部に掲載 (本記事内で詳細には解説しませんが、流れや実施内容がわかるようにしたいと思います)

全体の流れ

 大きく分けると 7ステップ になります。

  1. データ前処理・読み込み
  2. Dataset / DataLoader の準備
  3. CNNモデルの定義
  4. 順伝播(forward)
  5. 損失関数・最適化手法
  6. 学習ループ
  7. 結果の可視化(正解・不正解)

実装

1. データ前処理・読み込み

 MNIST データセットに対してテンソル変換および正規化を行い、ニューラルネットワークで学習可能な入力表現へ変換します。

  • ToTensor()
    • 画像を (H, W) → (C, H, W) に変換
    • 値を [0, 255] → [0, 1] に正規化
  • Normalize((0.5,), (0.5,))
    • [0,1] → [-1,1] にスケーリング
    • 学習を安定させる(勾配が暴れにくい)
      ※ 「正規化」で入力のスケールを揃える
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.ToTensor(),          # [0,255] → [0,1]
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)

2. Dataset / DataLoader の準備

 訓練用データと評価用データを分離し、ミニバッチ学習が可能な形でデータローダを構築します。

  • Dataset:
    • 画像+ラベルの集合体
  • DataLoader:
    • ミニバッチ化
    • シャッフル
    • GPU転送しやすくする
      ※ 「for x, t in train_loader」だけで学習が回せる状態を作っている
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

3. CNNモデルの定義

 畳み込み層・プーリング層・全結合層からなる畳み込みニューラルネットワークを定義します。

  • 畳み込み層
    • 局所特徴(線・角・丸み)を抽出
    • 重み共有 → 位置ずれに強い
  • プーリング層
    • 空間サイズを半分に
    • 細かい位置情報を捨てる
    • 歪み・ズレ耐性を獲得
  • 全結合層
    • 抽出した特徴を使って最終判断
    • クラス分類器の役割
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        # 畳み込み層
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # プーリング層
        self.pool = nn.MaxPool2d(2, 2)
        # 全結合層
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

4. 順伝播(forward)

 入力画像から出力スコア(ロジット)を計算する一連の処理を定義します。
流れ
画像

畳み込み + ReLU

プーリング

畳み込み + ReLU

プーリング

Flatten

全結合

クラススコア

    def forward(self, x):
        x = torch.relu(self.conv1(x))   # (B, 32, 28, 28)
        x = self.pool(x)                # (B, 32, 14, 14)
        x = torch.relu(self.conv2(x))   # (B, 64, 14, 14)
        x = self.pool(x)                # (B, 64, 7, 7)
        x = x.view(x.size(0), -1)       # flatten
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

5. 損失関数・最適化手法

 モデルの予測と正解ラベルの誤差を定量化し、パラメータ更新方法を定義します。

  • CrossEntropyLoss
    • softmax + log + NLL をまとめたもの
  • Adam
    • 学習率を自動調整
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

6. 学習ループ

 誤差逆伝播法によりネットワークの重みを反復的に更新します。

  • 各epochでやっていること
    • 順伝播
    • 損失計算
    • 逆伝播(勾配計算)
    • パラメータ更新
for epoch in range(10):
    model.train()
    for x, t in train_loader:
        x, t = x.to(device), t.to(device)

        optimizer.zero_grad()
        y = model(x)
        loss = criterion(y, t)
        loss.backward()
        optimizer.step()

    # 検証
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, t in test_loader:
            x, t = x.to(device), t.to(device)
            y = model(x)
            pred = y.argmax(dim=1)
            correct += (pred == t).sum().item()
            total += t.size(0)

    acc = correct / total
    print(f"Epoch {epoch+1}: Test Accuracy = {acc:.4f}")

※ Test Accuracyが出力される
スクリーンショット 2026-01-05 19.34.56.png

7. 結果の可視化(正解・不正解)

 テストデータを用いて汎化性能を評価し、正解・不正解例を可視化してモデルの振る舞いを分析します。

import matplotlib.pyplot as plt
import numpy as np

def collect_correct_incorrect(model, dataloader, device, max_samples=20):
    model.eval()

    correct = []
    incorrect = []

    with torch.no_grad():
        for x, t in dataloader:
            x, t = x.to(device), t.to(device)
            y = model(x)
            pred = y.argmax(dim=1)

            for i in range(x.size(0)):
                img = x[i].cpu().squeeze().numpy()
                true = t[i].item()
                p = pred[i].item()

                if true == p and len(correct) < max_samples:
                    correct.append((img, true, p))
                elif true != p and len(incorrect) < max_samples:
                    incorrect.append((img, true, p))

                if len(correct) >= max_samples and len(incorrect) >= max_samples:
                    return correct, incorrect

    return correct, incorrect

def show_correct_incorrect(correct, incorrect, n=10):
    fig = plt.figure(figsize=(12, 4))

    # 正解例
    for i, (img, t, p) in enumerate(correct[:n]):
        ax = fig.add_subplot(2, n, i + 1)
        ax.imshow(img, cmap="gray")
        ax.set_title(f"✓ T:{t} P:{p}", fontsize=9)
        ax.axis("off")

    # 不正解例
    for i, (img, t, p) in enumerate(incorrect[:n]):
        ax = fig.add_subplot(2, n, n + i + 1)
        ax.imshow(img, cmap="gray")
        ax.set_title(f"✗ T:{t} P:{p}", fontsize=9)
        ax.axis("off")

    plt.suptitle("MNIST CNN Classification Results")
    plt.tight_layout()
    plt.show()

correct, incorrect = collect_correct_incorrect(
    model,
    test_loader,
    device,
    max_samples=20
)

show_correct_incorrect(correct, incorrect, n=10)

※ 予測値Pと正解値Tの正誤画像
スクリーンショット 2026-01-05 19.37.27.png

最後に

 分類結果の画像を見ると人間でも判別がなかなか難しいものは誤りがあるが、例えば傾いた数字の9など分類できています。
これは非常に良いモデル挙動と言えるかと思います。入力をX次元のベクトルと見るMLPに比べると、正確に判別できているようです。

MLPとCNNの決定的な違いを表にまとめます。

観点 MLP CNN
入力 ベクトル 画像
空間構造 失う 保つ
崩れ耐性 弱い 強い
特徴 人任せ 自動抽出

参考文献、リンク

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?