1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

MNIST手書き数字のCNN画像認識

Last updated at Posted at 2022-12-28

MNIST手書き数字のCNN画像認識(畳み込み)は至る所にありますが、ここでは PyTorchチュートリアル を紹介したいと思います。
PyTorchチュートリアル/[6] torch.nnの解説

【関連記事】
MNIST手書き数字のCNN画像認識 - Qiita
CNN 畳み込み層のメモ - Qiita
Softmax+CrossEntropy の実装 - Qiita
機械学習のウォーミングアップ(Numpy) - Qiita

1. Data の取得と処理

  • Python の pathlibrequests、pickle を使って Data の取得・読み込みを行います。
  • 次に読み込んだ Numpy配列をテンソル化します。
  • 最後にTensorDataset と DataLoader で fit の繰り返し処理に対応しておきます。
# Data のダウンロード
from pathlib import Path
import requests

DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
PATH.mkdir(parents=True, exist_ok=True)
URL = "https://github.com/pytorch/tutorials/raw/master/_static/"
FILENAME = "mnist.pkl.gz"
if not (PATH / FILENAME).exists():
        content = requests.get(URL + FILENAME).content
        (PATH / FILENAME).open("wb").write(content)

# Data の読み込み
import pickle
import gzip
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

# Numpy配列 Data のテンソル化
import torch
x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)

# TensorDataset と DataLoader で fit の繰り返し処理に対応
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
train_ds = TensorDataset(x_train, y_train)
valid_ds = TensorDataset(x_valid, y_valid)

# train と valid の DataLoader を同時取得
def get_data(train_ds, valid_ds, bs):
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),
        DataLoader(valid_ds, batch_size=bs * 2),
    )

2. 損失関数

損失関数 F.cross_entropy 関数 は以下の記事における log_softmax 関数nil 関数を組み合わせた処理になります。是非、以下の記事を参照してください。
Softmax+CrossEntropy の実装 - Qiita

import torch.nn.functional as F
loss_func = F.cross_entropy

3. fit 関数

fit 関数 は訓練(学習)の繰り返し処理を行う部分を切り出したものです。

まず fit 関数 を簡潔にするために、損失計算と勾配計算の部分を loss_batch 関数 として切り出しておきます。

def loss_batch(model, loss_func, xb, yb, opt=None):
    loss = loss_func(model(xb), yb)

    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()

    return loss.item(), len(xb)

訓練(学習)の繰り返し処理を行う fit 関数 です。

import numpy as np

def fit(epochs, model, loss_func, opt, train_dl, valid_dl):
    for epoch in range(epochs):
        model.train()
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)

        model.eval()
        with torch.no_grad():     # 訓練でない時は、計算履歴を保持しない。
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            )
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)

        print(epoch, val_loss)

model.train() と model.eval() でトレーニングモードと予測モードの切り替えを行っています。モードによって処理内容が異なる Pytorch 関数 があるので必要となります。

[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]

上のリスト内包表記の結果は以下のようなリストになります。
$[ (loss_1, num_1), (loss_2, num_2), ...., (loss_n, num_n) ]$
但し、n は validデータの総数をバッチサイズで割った数とします。(正確にはバッチサイズの2倍で割る)。

このリストをunpackしzipして、loss と num のタプルを取得しています。loss はもともとバッチの平均値だったので、バッチサイズのnumを掛けて、その総和を求めることで、valid データ全体の loss の総和を求めています。それを valid データ全体の長さで割っています。

unpack と zip の動作は以下のテストを確認してください。タプルはそのまま numpy で処理できるようですね。

>>> import numpy as np
>>> test_list = [ (1, 11), (2, 12), (3, 13) ]
>>> x,y = zip(*test_list)
>>> x
(1, 2, 3)
>>> y
(11, 12, 13)
>>> np.multiply(x,y)
array([11, 24, 39])

4. Model の定義

ここではCNN バージョンのみ示していますが、チュートリアルでは、線形 model を試してから、 CNN model を試しています。model の定義を変更するだけで、その他の処理はそのまま使えるというストーリで記述されています。
PyTorchチュートリアル/[6] torch.nnの解説

畳み込み処理の出力の高さ・横幅
一般的に H, W を入力の高さ・横幅、 stride=S, padding=P とすれば

OH = (H + 2P - FH)/S + 1
OW = (W + 2P - FW)/S + 1

例えば、H=28, stride=2, padding=1の場合は以下の通りとなる。

OH = (H + 2P - FH)/S + 1 = (H+2-3)/2 +1 = (H-1)/2 + 1 =(28-1)/2 +1 = 14

以下が model の定義です。それぞれ 畳み込み処理プーリング処理出力の高さ・横幅 の計算結果をコメントにしてみました。

class Mnist_CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)  # OH = 14
        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1) # OH = 7
        self.conv3 = nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1) # OH = 4

    def forward(self, xb):
        xb = xb.view(-1, 1, 28, 28)  # [64, 1, 28, 28]
        xb = F.relu(self.conv1(xb))  # [64, 16, 14, 14]
        xb = F.relu(self.conv2(xb))  # [64, 16, 7, 7]
        xb = F.relu(self.conv3(xb))  # [64, 10, 4, 4]
        xb = F.avg_pool2d(xb, 4)     # [64, 10, 1, 1]
        return xb.view(-1, xb.size(1))  # [64, 10]

F.relu活性化関数 ReLU ですが、定義は以下の記事を参照してください。
機械学習のウォーミングアップ(Numpy) - Qiita

最初に xb.view(-1, 1, 28, 28) で入力データを4次元に変更しています。これは CNN のデータの流れが4次元で行われることに対応しています。入力データはもともと (N, 784) の形で流れてくるので、線形model ではそのまま扱いますが、CNN model ではここで (N, 1, 28, 28) に変形してCNN のネットワークに流しています。

この Model の特徴は線形関数を含んでいないことです。このModelの最終目的は出力を [64, 10] にして、10個の分類を行うことです。 そのために、線形関数の classfier を用意して調整することがあります。しかしここでは Conv2d と avg_pool2d を巧みに使って、出力データを [64, 10, 1, 1] に落とし込み、最後に view(-1, xb.size(1)) で [64, 10] にすることで 特別な classfier を用いずに済ませています。

5. Model のインスタンスと最適関数

Model のインスタンスと最適関数の取得を同時に行えるようにします。Modelで使用するパラメータは model.parameters() で取得できるので、それに対して最適化関数を適用します。momentum を指定することで直近の勾配地だけでなく、過去の勾配地も考慮するようにします。

lr = 0.1
def get_model():
    model = Mnist_CNN()
    return model, optim.SGD(model.parameters(), lr=lr, momentum=0.9)

6. 訓練

ここまでの準備で、実質以下の3行で訓練(学習)を行うことができるようになりました。

bs = 64  # batch size
epochs = 10

train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(epochs, model, loss_func, opt, train_dl, valid_dl)

実行結果です。損失値がゼロに近づいているのがわかります。おおむね成功ですね。

0 0.36378984227180483
1 0.267542066693306
2 0.21404471072554587
3 0.1797706115782261
4 0.19904899366497994
5 0.15275968505442142
6 0.1643963666319847
7 0.15317361530661583
8 0.13119296333193778
9 0.13225660784840584

今回は以上です

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?