LoginSignup
1
1

教師あり学習:畳み込みニューラルネットワーク

Last updated at Posted at 2024-04-14

はじめに

前回 教師あり学習:画像の分類 の続きです。

前回はMNISTという手書き数字のデータセットで教師あり学習を実装しました。
結果、正解率が90%ほどとなりました。

今回は正解率をさらに上げることが目標です。

畳み込みニューラルネットワーク

畳み込みニューラルネットワーク( CNN )とは、フィルターとよばれる特徴画像を生成し、その特徴画像で学習を行います。
画像の分類には、このような処理を行うと正解率が上がるといわれています。

前回のようなモデルを MLP( 多層パーセプトロン ) といいます。
今回の畳み込みニューラルネットワーク( CNN )は、
畳み込み処理 + MLP のようなイメージです。

実装

実行環境

Google Colabでの実行を想定しています。詳しくはこちらを参照してください。

ライブラリの読み込み

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import torchvision.datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.utils as utils
import numpy as np

使用するライブラリは前回と同じです。

データの準備

train_data = torchvision.datasets.MNIST(root="\data",
                                        download=True,
                                        train=True,
                                        transform=transforms.ToTensor())
test_data = torchvision.datasets.MNIST(root="\data",
                                       download=True,
                                       train=False,
                                       transform=transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(dataset=train_data,
                                           batch_size=64,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data,
                                          batch_size=64,
                                          shuffle=False)

データの準備も前回と同じです。前回ダウンロードしている場合、ダウンロードは行われません。

CUDAの確認

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

畳み込みニューラルネットワークの計算は少し大変ですので、CUDAが使える環境ではCUDAを使用します。

学習モデルの定義

class CNN(nn.Module):   
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3)
        self.conv2 = nn.Conv2d(16, 32, 3)

        self.pool = nn.MaxPool2d(2, 2)

        self.fc1 = nn.Linear(32 * 5 * 5, 32)
        self.fc2 = nn.Linear(32, 16)
        self.fc3 = nn.Linear(16, 10)
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))

        x = x.view(-1, 32 * 5 * 5)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
net = CNN()
net.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.005)

フィルターサイズを $3 \times 3$ にしました。
畳み込み、プーリングをそれぞれ2回ずつ行います。
optimizerAdam に変更しました。

学習

for epoch in range(5):
    total_loss = 0
    for train_x, label in train_loader:
        
        train_x, label = train_x.to(device), label.to(device)
        
        optimizer.zero_grad()
        loss = criterion(net(train_x), label)
        loss.backward()
        optimizer.step()
        total_loss += loss.data
        
        train_x, label = train_x.to("cpu"), label.to("cpu")
    
    if(epoch + 1) % 1 == 0:
        print(epoch + 1, total_loss)
        
net.to("cpu")

前回との違いは、reshape の処理をしていません。畳み込みやその他諸々の処理をクラスの中で定義しているため、ここではそのままのデータを渡します。

評価

net.eval()

with torch.no_grad():
    correct = 0
    total = 0
    for train_x, train_label in train_loader:
        outputs = net(train_x)
        _, predicted = torch.max(outputs.data, 1)
        total += train_label.size(0)
        correct += (predicted == train_label).sum().item()

    train_accuracy = correct / total

with torch.no_grad():
    correct = 0
    total = 0
    for test_x, test_label in test_loader:
        outputs = net(test_x)
        _, predicted = torch.max(outputs.data, 1)
        total += test_label.size(0)
        correct += (predicted == test_label).sum().item()

    test_accuracy = correct / total

print(f"学習用データの正解率:{100 * train_accuracy:.2f}%")
print(f"テスト用データの正解率:{100 * test_accuracy:.2f}%")

スクリーンショット (997).png

前回に比べて大分正解率が上がりました。

確認

net.eval()

predicted_labels = []
true_labels = []

with torch.no_grad():
    for i, (test_x, test_label) in enumerate(test_loader):
        output = net(test_x)
        _, predicted = torch.max(output, 1)
        predicted_labels.extend(predicted.tolist())
        true_labels.extend(test_label.tolist())
        
        if i == 9:
            break

print("予想:", predicted_labels[:10])
print("正解:", true_labels[:10])

スクリーンショット (1000).png
先頭の10個のデータはすべて正解しています。

import matplotlib.pyplot as plt

num_images_per_row = 10

fig, axes = plt.subplots(nrows=1, ncols=num_images_per_row, figsize=(10, 2))

for idx, (images, labels) in enumerate(test_loader):
    if idx < 1:
        for i, image in enumerate(images):
            ax = axes[i]
            ax.imshow(np.transpose(image, (1, 2, 0)), cmap='gray', vmin=0, vmax=1)
            ax.set_title("label:" + str(predicted_labels[i]))
            ax.axis('off')

            if i == num_images_per_row - 1:
                break
    else:
        break

plt.tight_layout()
plt.show()

スクリーンショット (999).png

前回は間違えていた 5 の画像 も正解しています。

おわりに

前回に比べて計算が大変になりますが、確かに正解率が上がったように感じます。
次回は画像のノイズを取り除く処理( オートエンコーダ )を実装します。

次回

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1