0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

教師あり学習:画像の分類

Last updated at Posted at 2024-04-14

はじめに

前回は教師あり学習の方法を用いて、ワインのデータセットで分類を実装しました。
前回はこちらから

今回は画像の分類をやります。

MNISTという有名なデータセットで画像の分類をします。

MNISTとは?

MNIST(えむにすと)とは、手書きの数字の画像が大量に入ってるデータセットです。郵便番号の数字が集められています。

image.png
中身はこのようになっています。
データの一つ一つは $28 \times 28$ ピクセルの画像と 0 ~ 9 のラベル情報がペアとなっています。

実行環境

Google Colabでの実行を想定しています。詳しくはこちらを参照してください。

実装

ライブラリの読み込み

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import torchvision.datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.utils as utils
import numpy as np
import matplotlib.pyplot as plt

PyTorch, numpy, matplotlibを使います。

データセットのダウンロード

train_data = torchvision.datasets.MNIST(root="\data",
                                        download=True,
                                        train=True,
                                        transform=transforms.ToTensor())
test_data = torchvision.datasets.MNIST(root="\data",
                                       download=True,
                                       train=False,
                                       transform=transforms.ToTensor())

MNISTはよく機械学習で用いられるので、PyTorchに処理が用意されています。
torchvision.datasets.MNIST() という関数を使います。MNISTのデータを持ってきてくれる関数です。

引数の説明

root="\data"
データを保存するディレクトリを指定する引数です。
"\data" は任意のディレクトリのパスを指定できます。ここでは \data としています。
このコードを動かすと、勝手に data ディレクトリが作られます。

download=True
ダウンロードをするか否かを指定する引数です。
True であればダウンロードします。
ただし root="\data" で指定したディレクトリに、すでにMNISTのデータが入ってる場合はダウンロードしません。

train = True
学習用データとテスト用データの選択します。
True であれば学習用データ、
False であればテスト用データです。

transform = transforms.ToTensor()
Tensor型に変換しています。

データの読み込み

train_loader = torch.utils.data.DataLoader(dataset=train_data,
                                           batch_size=64,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data,
                                          batch_size=64,
                                          shuffle=False)

torch.utils.data.DataLoader() でデータ取り出しの準備をします。

引数の説明

dataset=train_data
読み込むデータの指定をします。train_loader であれば train_data を読み込みます。

batch_size=64
バッチサイズの指定をします。64としました。

shuffle=True
True であればデータをランダムに取り出します。
テスト用の方は結果が変わらないので False とします。

学習モデルの定義

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 32)
        self.fc2 = nn.Linear(32, 16)
        self.fc3 = nn.Linear(16, 10)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        
        return x
    
net = Net()

学習に用いるニューラルネットワークのモデルを定義します。

入力層は画像サイズを指定します。MNISTの画像サイズは $28 \times 28$ なのでそうなっています。
中間層は適当な値を設定しています。結果を観察しつつ調整します。
出力層は10を指定します。推定するラベルが0~9の10種類であるためです。
活性化関数には ReLU を使います。
net = Net() でインスタンス化しています。

誤差関数と最適化手法

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.005)

criterion = nn.CrossEntropyLoss()
誤差関数は交差エントロピーを使います。

optimizer = optim.SGD(net.parameters(), lr=0.005)
最適化手法は SGD( 確率的勾配降下法 )を使います。
学習率( lr )は 0.005 です。

最適化手法は畳み込みニューラルネットワーク( CNN )のときは、適応モーメント推定( Adam )を使うこともあります。
CNNは次回やります。

学習

for epoch in range(10):
    total_loss = 0
    for train_x, label in train_loader:
        
        train_x = train_x.reshape(-1, 28*28)
        
        optimizer.zero_grad()
        loss = criterion(net(train_x), label)
        loss.backward()
        optimizer.step()
        total_loss += loss.data
    
    if(epoch + 1) % 1 == 0:
        print(epoch + 1, total_loss)

エポック数は10としました。

エポック数とは?

同じデータを何回使いまわして学習するか、です。MNISTの場合は60000個のデータがあるので、60000のデータを10回学習します。
エポック数が少なすぎると学習が途中で終わる可能性があります。
逆に多すぎると、計算時間が増えることに加え、 過学習 となる場合があります。

過学習 ... システムが学習データに最適化され、本番に弱いシステムとなってしまう現象

train_x = train_x.reshape(-1, 28*28) で $28 \times 28$ の2次元配列から $1 \times 784$ の1次元配列にします。計算を行う際は1次元配列でないとエラーを返されます。

正解率の計算

net.eval()

with torch.no_grad():
    correct = 0
    total = 0
    for train_x, train_label in train_loader:
        train_x = train_x.reshape(-1, 28*28)
        outputs = net(train_x)
        _, predicted = torch.max(outputs.data, 1)
        total += train_label.size(0)
        correct += (predicted == train_label).sum().item()

    train_accuracy = correct / total

with torch.no_grad():
    correct = 0
    total = 0
    for test_x, test_label in test_loader:
        test_x = test_x.reshape(-1, 28*28)
        outputs = net(test_x)
        _, predicted = torch.max(outputs.data, 1)
        total += test_label.size(0)
        correct += (predicted == test_label).sum().item()

    test_accuracy = correct / total

print(f"学習用データの正解率:{100 * train_accuracy:.2f}%")
print(f"テスト用データの正解率:{100 * test_accuracy:.2f}%")

学習データと検証データで正解率を計算しています。

結果

スクリーンショット (994).png

正解率は90%程度でした。

確認

net.eval()

predicted_labels = []
true_labels = []

with torch.no_grad():
    for i, (test_x, test_label) in enumerate(test_loader):
        test_x = test_x.reshape(-1, 28*28)
        output = net(test_x)
        _, predicted = torch.max(output, 1)
        predicted_labels.extend(predicted.tolist())
        true_labels.extend(test_label.tolist())
        
        if i == 9:
            break

print("予想:", predicted_labels[:10])
print("正解:", true_labels[:10])

スクリーンショット (995).png

どうやら 5 の画像を 6 と見間違えているようです。

画像と見比べる

num_images_per_row = 10

fig, axes = plt.subplots(nrows=1, ncols=num_images_per_row, figsize=(10, 2))

for idx, (images, labels) in enumerate(test_loader):
    if idx < 1:
        for i, image in enumerate(images):
            ax = axes[i]
            ax.imshow(np.transpose(image, (1, 2, 0)), cmap='gray', vmin=0, vmax=1)
            ax.set_title("label:" + str(predicted_labels[i]))
            ax.axis('off')

            if i == num_images_per_row - 1:
                break
    else:
        break

plt.tight_layout()
plt.show()

スクリーンショット (996).png

たしかに、5 の画像6 と判定しています。それ以外は正解しているので、正解率90%は妥当な数値でしょう。

おわりに

MNISTという機械学習では有名なデータセットを用いて画像の分類を実装しました。
次回は正解率をさらに上げるべく、畳み込みニューラルネットワーク( CNN )を実装します。

次回

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?