LoginSignup
27
34

More than 3 years have passed since last update.

PyTorchによる画像分類チュートリアル

Last updated at Posted at 2019-06-15

はじめに

本記事ではPyTorch公式を参考に、機械学習ライブラリPyTorchによる画像分類のチュートリアルを行います。

前提条件

・PyTorchインストール済み
・Jupyterで行うことを想定しNotebook形式で進める(script形式でも問題はない)

流れ

1.torchvisionを用いてCIFAR10のデータセットを読み込む
2.CNN(畳み込みニューラルネットワーク)の定義
3.損失関数の定義
4.訓練データを使って学習
5.テストデータを使って予測

1.CIFAR10のデータセットを読み込む

PyTorchにはtorchvisionと呼ばれるライブラリが含まれており、機械学習の画像データセットとしてよく使われているImagenet, CIFAR10, MNISTなどが利用できます。
今回のチュートリアルではCIFAR10のデータセットを利用します。
はじめに以下をインポートします。

import torch
import torchvison
import torchvision.transforms as transform

torchvisionから得られるデータセットは[0, 1]の範囲で表されるPIL画像で出力されます。
今回はこれらを学習のために[-1, 1]の範囲に正規化した行列に変換します。

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # Normalize(平均, 偏差)

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

データセットのダウンロードに多少時間がかかりますが、終了すると以下のように表示が出ます。

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz

Files already downloaded and verified

読み込んだ画像とラベルのデータを試しに表示させてみます。

import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline # Notebook形式の場合画像を表示させるのに必要

def imshow(img):
    img = img / 2 + 0.5     
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# 訓練データをランダムに取得
dataiter = iter(trainloader)
images, labels = dataiter.next()

# 画像の表示
imshow(torchvision.utils.make_grid(images))
# ラベルの表示
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

これを実行すると下記のように表示されます。
上が読み込んだ画像のサンプル、下が画像に対応するラベルです。
image.png

dog  ship   cat  deer

2.CNN(畳み込みニューラルネットワーク)の定義

データの用意ができたので次に画像認識において有効なニューラルネットワークであるCNN(Convolutional Neural Network)のをpytorch内のライブラリを用いて定義します。

# ニューラルネットワーク(NN)を構成する際に使用するライブラリ
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    # NNの各構成要素を定義
    def __init__(self):
        super(Net, self).__init__()

        # 畳み込み層とプーリング層の要素定義
        self.conv1 = nn.Conv2d(3, 6, 5)  # (入力, 出力, 畳み込みカーネル(5*5))
        self.pool = nn.MaxPool2d(2, 2)  # (2*2)のプーリングカーネル
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 全結合層の要素定義
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # (入力, 出力)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)  # クラス数が10なので最終出力数は10

    # この順番でNNを構成
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # conv1->relu->pool
        x = self.pool(F.relu(self.conv2(x)))  # conv2->relu->pool
        x = x.view(-1, 16 * 5 * 5)  # データサイズの変更
        x = F.relu(self.fc1(x))  # fc1->relu
        x = F.relu(self.fc2(x))  # fc2->relu
        x = self.fc3(x)
        return x


net = Net()

今回定義したCNNは、
入力層 → 畳み込み層 → プーリング層 → 畳み込み層 → プーリング層 →
全結合層 → 全結合層 → 全結合層 → 出力層
という構成になっています。

forward関数内にあるview関数は1つ目の引数に-1を入れることで、2つ目の引数で指定した値にサイズ数を自動的に調整してくれます。

3.損失関数とオプティマイザの定義

損失関数とオプティマイザは以下の関数で簡単に定義できます。
今回は損失関数に交差エントロピーを、オプティマイザにAdamを使用します。学習率は0.001です。


import torch.optim as optim  #オプティマイザ用のライブラリ

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

ニューラルネットワークの訓練

for epoch in range(2): 

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):

        # 訓練データから入力画像の行列とラベルを取り出す
        inputs, labels = data

        # 勾配パラメータを0にする
        optimizer.zero_grad()

        # 順伝播 → 逆伝播 → 勾配パラメータの最適化
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # 損失関数の変化を2000ミニバッチごとに表示
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')
[1,  2000] loss: 1.334
[1,  4000] loss: 1.326
[1,  6000] loss: 1.344
[1,  8000] loss: 1.308
[1, 10000] loss: 1.296
[1, 12000] loss: 1.285
[2,  2000] loss: 1.194
[2,  4000] loss: 1.195
[2,  6000] loss: 1.192
[2,  8000] loss: 1.204
[2, 10000] loss: 1.191
[2, 12000] loss: 1.189
Finished Training

損失関数が徐々に下がっていってるのがわかります。

5.テストデータによるネットワークのテスト

ランダムに取り出したテストデータの画像とラベルを試しに表示させてみます。

# 訓練データをランダムに取得
dataiter = iter(testloader)
images, labels = dataiter.next()

# 画像の表示
imshow(torchvision.utils.make_grid(images))
# ラベルの表示
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))

image.png

GroundTruth:    cat  ship  ship  plane

訓練したNNを用いて予測を行います。

outputs = net(images) # 訓練後のNNに画像を入力
_, predicted = torch.max(outputs, 1) #入力した画像の行列の最大値(もっとも確率の高いもの)を返す

print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
                              for j in range(4)))
Predicted:    cat  ship  ship  plane

上記の結果から4つの画像全てにおいて正しく予測できていることがわかります。

今回ランダムに取り出した4つの画像は全て正解を予測できていましたが、他の画像ではどうなのかを確かめる必要があります。
そこで10000枚あるテストデータがどの程度正しく予測できているかを確かめてみます。


correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))
Accuracy of the network on the 10000 test images: 58 %

上記より58%の精度で分類が正しく行われていることがわかります。
つまり今回取り出した画像は偶然4つとも正解を予測できていましたが、全体で見ると6割程度の予測精度のようです。

最後に各ラベルごとの予測制度も出してみたいと思います。

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(4):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1


for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))
Accuracy of plane : 60 %
Accuracy of   car : 71 %
Accuracy of  bird : 56 %
Accuracy of   cat : 45 %
Accuracy of  deer : 44 %
Accuracy of   dog : 51 %
Accuracy of  frog : 61 %
Accuracy of horse : 58 %
Accuracy of  ship : 68 %
Accuracy of truck : 67 %

各ラベルの予測精度は上記のようになりました。

おわりに

今回はPyTorchを勉強する目的で画像分類のチュートリアルを整理しつつ実装してみました。
以前はKerasを使用していましたが、KaggleでPyTorchを使用したのを機に使い始めました。機械学習分野はまだまだ学ぶことが多いので引き続き勉強していきたいと思います。

参考

PyTorch公式

27
34
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
27
34