LoginSignup
12
8

More than 3 years have passed since last update.

【PyTorchチュートリアル④】TRAINING A CLASSIFIER

Last updated at Posted at 2020-04-28

はじめに

前回に引き続き、PyTorch 公式チュートリアル の第4弾です。
今回は TRAINING A CLASSIFIER を進めます。

クラス分類の学習

前回は、ニューラルネットワークを定義して損失を計算し、ネットワークの重みを更新する方法を見ました。
今回は画像データを利用して、分類問題について見ていきます。
PyTorch には 画像データを扱うためのライブラリ torchvision が用意されています。
torchvision には Imagenet、CIFAR-10、MNISTなどの基本的な画像データセットもあらかじめ用意されています。
このチュートリアルでは CIFAR-10 のデータセットを使用します。

CIFAR-10

CIFAR-10 は 32x32 ピクセルのカラー画像でRGBの3チャネルあり、(3,32,32)の形状になっています。
クラスは「飛行機」、「自動車」、「鳥」、「猫」、「鹿」、 「犬」、「カエル」、「馬」、「船」、「トラック」の10種類にラベル分けされています。

cifar10.png

画像分類の学習

CIFAR-10 を利用して学習の流れを見ていきましょう。
画像をクラス分けする学習は、次の手順を順番に実行します。

  1. torchvisionを使用してCIFAR10トレーニングおよびテストデータセットを読み込み、正規化する
  2. 畳み込みニューラルネットワークの定義
  3. 損失関数と最適化アルゴリズムを定義する
  4. トレーニングデータで学習する
  5. テストデータで学習した結果を確認する

1 . CIFAR-10 の読み込みと正規化

PyTorchでデータを読み込むには DatasetDataLoader を使うと便利です。
Dataset は画像と正解ラベル(1つの学習データ)を保持します。
DataLoader 学習データ(テストデータ)を繰り返し取得するためのユーティリティです。

import torch
import torchvision
import torchvision.transforms as transforms

# transform の定義
# ToTensor で Tensorに変換し
# 標準化 ( X - 0.5) / 0.5 
transform = transforms.Compose([
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# トレーニング用データセット
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

# トレーニング用データローダ
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=False, num_workers=2)

# テスト用データセット
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

# テスト用データローダ
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

# クラス分類の一覧
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

先ず、transform でデータセットの変換方法を設定しています。
Compose は複数の変換を順に繋げます。
ここでは、ToTensor と Normalize を実行しています。
ToTensor は画像データを Tensor に変換します。ToTensor では、RGBの値は [0, 1] の fload 型で表されます。
Normalize で ( X - 0.5) / 0.5 を計算し、RGBの範囲を [-1, 1] に変換しています。

以下のコードで一部の画像が表示されます。
trainloader の batch_size が 4 ですので4つの画像がまとめて処理されます。

import matplotlib.pyplot as plt
import numpy as np

# 画像を表示する関数
def imshow(img):
    img = img / 2 + 0.5     # 標準化を戻す
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))


# トレーニングデータから画像を取得する(ランダム)
dataiter = iter(trainloader)
images, labels = dataiter.next()

# 画像を表示
imshow(torchvision.utils.make_grid(images))
# ラベルを表示
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

frog truck truck deer
imshow.png

2 . 畳み込みニューラルネットワークの定義

前回のニューラルネットワークのチュートリアルからニューラルネットワークをコピーし、3チャネルの画像(定義されていた1チャネルの画像ではなく)を取得するように変更します。

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()

3 . 損失関数と最適化アルゴリズムを定義する

分類問題では、主に損失関数は交差エントロピー誤差関数を利用します。
二値分類問題では、torch.nn.BCEWithLogitsLoss
多クラス分類問題では、torch.nn.CrossEntropyLoss
がよく利用されます。
今回は10個のラベル分類(多クラス分類)ですので、CrossEntropyLoss を利用します。
最適化アルゴリズムは、最も基本的な「確率的勾配降下法」(SGD) にします。

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

4 . トレーニングデータで学習する

データローダをループし、トレーニングデータをネットワークに流してパラメータを最適化します。

for epoch in range(2):  # エポック数回分ループ

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # トレーニングデータを取得する
        inputs, labels = data

        # 勾配を初期化する
        optimizer.zero_grad()

        # ニューラルネットワークにデータを通し、順伝播を計算する
        outputs = net(inputs)
        # 誤差の計算
        loss = criterion(outputs, labels)
        # 逆伝播の計算
        loss.backward()
        # 重みの計算
        optimizer.step()

        # 状態を表示する
        running_loss += loss.item()
        if i % 2000 == 1999:    # 2,000 データずつ
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

2分ほどで学習が完了します。

[1,  2000] loss: 2.153
[1,  4000] loss: 1.830
[1,  6000] loss: 1.654
[1,  8000] loss: 1.556
[1, 10000] loss: 1.524
[1, 12000] loss: 1.511
[2,  2000] loss: 1.441
[2,  4000] loss: 1.380
[2,  6000] loss: 1.384
[2,  8000] loss: 1.358
[2, 10000] loss: 1.335
[2, 12000] loss: 1.320
Finished Training

5 . テストデータで学習した結果を確認する

学習した結果をテストデータで確認します。
テストデータの一部を見てみましょう。

dataiter = iter(testloader)
images, labels = dataiter.next()

# print images
imshow(torchvision.utils.make_grid(images))
print('教師データ: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))

教師データ: cat ship ship plane
imshow.png

この4つの画像を学習したモデル(分類器)がどう判定するか見てみます。
学習したネットワークにテストデータを渡すと、結果が返ってきます。

outputs = net(images)
print(outputs[0,:])
tensor([-1.0291, -3.1352,  0.5837,  3.7728, -1.3638,  3.4090,  0.4094,  0.3352,
        -0.6388, -1.4808], grad_fn=<SliceBackward>)

output には各ラベル(10個)の重みが返却されます。

_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
                              for j in range(4)))

max で最大のインデックスを取得します。
max の返却値を2つ指定した場合、1つ目に最大値が、2つ目に最大値のインデックスが返却されます。

Predicted:    cat  ship  ship plane

教師データと一致することが確認できました。

テストデータ全体で結果を確認してみましょう。

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))
Accuracy of the network on the 10000 test images: 52 %

テストデータ 1万件で正解率は 52% でした。
学習できていない場合、1/10 (10%)なので何かは学習できてはいますが、あまり精度は高くありません。
ラベルごとに正解率を見てみます。

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(4):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1


for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))
Accuracy of plane : 65 %
Accuracy of   car : 51 %
Accuracy of  bird : 33 %
Accuracy of   cat : 22 %
Accuracy of  deer : 39 %
Accuracy of   dog : 64 %
Accuracy of  frog : 66 %
Accuracy of horse : 70 %
Accuracy of  ship : 53 %
Accuracy of truck : 61 %

cat など、あまり学習できていないケースがあるようです。

GPUでの学習

最後に GPU での学習を見ていきます。
GPUが利用できる環境の場合、ニューラルネットワークの学習に GPU を利用することができます。
Google Colaboratory を利用している場合、「ランタイム」⇒「ランタイムのタイプを変更」⇒「GPU」を選択することで GPU を利用することができます。

以下のコードで GPU の利用を確認できます。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Assume that we are on a CUDA machine, then this should print a CUDA device:
print(device)

device に「cuda」であれば GPU を利用できます。
GPU を利用する場合、モデルを GPU に転送するコードを記述する必要があります。

net.to(device)

同様に 学習データも GPU に転送する必要があります。

inputs, labels = inputs.to(device), labels.to(device)

Google Colaboratory でこのチュートリアルの学習にかかった時間を確認したところ
GPUを利用しなかった場合:2分10秒
GPUを利用した場合:1分40秒
でした。
学習用データの次元がそこまで複雑でない場合、効果は薄いようです。

終わりに

以上が、PyTorch の4つ目のチュートリアル「TRAINING A CLASSIFIER」の内容です。
学習の基本的な流れを理解することができました。

次回は5つ目のチュートリアル「LEARNING PYTORCH WITH EXAMPLES」を進めてみたいと思います。

履歴

2020/04/29 初版公開
2020/10/21 次回のリンク追加

12
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
8