LoginSignup
1
1

【PyTorch】CNN(Convolutional Neural Network)備忘録

Last updated at Posted at 2023-06-14

CNNが初めに紹介されたのは、LeNet(Backpropagation applied to handwritten zip code recognition, LeCun et al., 1989)であり、画像の特徴量抽出を畳み込み層とプーリング層で行うことで、画像認識の精度を向上させた。作成者+Netの略称を名付けた初期のモデルである。

CNN-LeNet1.PNG

PyTorchのデータロード、モデル定義、学習、評価、保存、推論の基本を一枚のコードで実装してみる。

PyTorch(CPU版)のインストール

pip3 install torch torchvision torchaudio

Pillowのインストール

pip3 install Pillow

一般的な実装の流れ

  1. データを読み込む
  2. 前処理を行う
  3. モデルを定義する
  4. 損失関数と最適化アルゴリズムを定義する
  5. 教師データを用いてモデルの学習を行う
  6. 学習済みのモデルを保存する
  7. 学習済みのモデルを読み込む
  8. 学習済みのモデルで推論する

Pythonコード全文

# PyTorchを使ってMNISTの手書き数字データセットを用いた畳み込みニューラルネットワークの学習

import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import PIL.ImageOps    
import PIL.Image as pilimg


# データの前処理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# MNISTデータセットの読み込み
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

# DataLoaderの作成
# ミニバッチ学習, 60,000枚の画像データを各バッチは64個のデータで分ける
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

# モデルの定義
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = nn.functional.log_softmax(x, dim=1)
        return output

# モデルのインスタンス化
model = Net()

# ハイパーパラメータの設定
learning_rate = 0.01
epochs = 10

# 損失関数と最適化アルゴリズムの定義
# 交差エントロピー損失関数
criterion = nn.CrossEntropyLoss()
# 確率的勾配降下法(Stochastic Gradient Descent, SGD)
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
# 全データ使用するバッチ勾配降下法
# いくつかのデータの誤差の和を使うミニバッチ勾配降下法
# いくつかのデータの誤差ごとを使う確率的勾配降下法(オンライン勾配降下法)

# 学習ループ
for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        # 入力データとラベルの取得
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

    # 1エポックごとにテストデータでモデルを評価
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print('Epoch: {} Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(epoch, test_loss, correct, len(test_loader.dataset), accuracy))

# モデルのパラメータのみを保存
torch.save(model.state_dict(), 'model_weight.pth')
# モデル全体を保存
torch.save(model, 'model.pth')


# モデルを読み込む
model = Net()
model.load_state_dict(torch.load('model_weight.pth'))

# 例:手書き数字の画像を読み込む
image = pilimg.open("test.png").convert('L')
image = PIL.ImageOps.invert(image)
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
image = transform(image).unsqueeze(0)

with torch.no_grad():
    output = model(image)
    # 出力の中で最大の値を持つインデックスを取得
    prediction = output.argmax(dim=1, keepdim=True)
    print("\nPrediction:", prediction.item())

実行結果

下記の画像をtest.pngで保存して、推論する。
test.png

結果
Prediction: 5

正しく推論できた。

おまけ:MNISTの手書き文字画像のダウンロード方法

# MNISTデータをPNGで保存
import os
from torchvision import datasets

# 保存先フォルダ設定
rootdir = "MNIST"
traindir = rootdir + "/train"
testdir = rootdir + "/test"

# MNIST データセット読み込み
train_dataset = datasets.MNIST(root=rootdir, train=True, download=True)
test_dataset = datasets.MNIST(root=rootdir, train=False, download=True)

# 画像保存 train
number = 0
for img, label in train_dataset:
    savedir = traindir + "/" + str(label)
    os.makedirs(savedir, exist_ok=True)
    savepath = savedir + "/" + str(number).zfill(5) + ".png"
    img.save(savepath)
    number = number + 1
    print(savepath)

# 画像保存 test
number = 0
for img, label in test_dataset:
    savedir = testdir + "/" + str(label)
    os.makedirs(savedir, exist_ok=True)
    savepath = savedir + "/" + str(number).zfill(5) + ".png"
    img.save(savepath)
    number = number + 1
    print(savepath)

まとめ

今回は、MNISTの手書き数字のPyTorchによる学習・モデル保存・推論を紹介した。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1