LoginSignup
2

More than 5 years have passed since last update.

PyTorch公式チュートリアル Deep Learning with PyTorch #4 Training a classifier

Posted at

分類器を学習させる

CIFAR10を使って実際にPyTorchで認識を行う分類器を構築する。今回の実装のgistはここ

What about data?

画像や文章、音声やビデオなどのデータを扱う場合、まずデータをnumpy arrayなどに読み込み、torch.*Tensorに変換する。画像に対してはPillow, OpenCV、音声に対してはscipylibrosa、文章に対してはPythonのデータローディングパッケージやNLTKSpaCyなどを使うと良い。

PyTorchでは特にコンピュータビジョンのためのパッケージ、torchvisionがあり、
CIFAR10やImagenetなど一般的なデータセットのデータローダーがある。
このチュートリアルではCIFAR10のデータセットを用いる。CIFAR10の入力データRGBの3チャネルの縦横が32ピクセルの3x32x32となっている。
CIFAR10

Training an image classifier(画像の分類器を学習させる)

分類器の学習のためには以下のようなステップを踏む必要がある。
1. CIFAR10の学習データセットとテストデータセットをtorchvisionで読み込み、正規化する。
2. CNNを定義する。
3. 損失関数を定義する。
4. 学習データを使ってネットワークを定義する。
5. テストデータを使ってテストを行う。

1. CIFAR10を読み込み、正規化する

torchvisionを使うことで、簡単にデータセットを読み込むことができる。torchvisionデータセットの出力はPILImageの[0,1]の画像であり、これを[-1,1]二世聞かされたのTensorに変換する。

import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

このスクリプトを実行するとcifar-10のデータセットが./dataいかにダウンロードされる。

2. CNNを定義する

CIFAR 10についてはRGBの3チャネルの画像となるため、チュートリアル#3の一つ目の畳み込み層を微修正して実装する。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

3. 損失関数及びオプティマイザを定義する

今回はMomentumつきSGDをオプティマイザとする。
SGDには谷間をナビゲートする上で問題がある(谷間とは、表面がカーブしているエリアのことで、ある次元では他の次元よりカーブがもっと急峻になっている部分)。Momentumは関連性のある方向へSGDを加速させ振動を抑制する方法である1。SGDに慣性化項を付与したもの。
重み$\bf{w}$を更新する場合は、$\bf{E}$を誤差関数、$\eta$を学習率とすると、
$$\bf{w^{t+1}} \leftarrow \bf{w^{t}} - \eta \frac{\partial E(\bf{w^t} )}{\partial \bf{w^t}}$$

で重みが更新される2
一方MomentumつきSGDは慣性項のパラメータを$\alpha$とすると、前回の更新量を$\alpha$倍して加算することでパラメータの更新をより慣性的にする。この手法はパラメータが二つあるため、最適化が難しいとも言われる。
 
$$\bf{w^{t+1}} \leftarrow \bf{w^{t}} - \eta \frac{\partial E(\bf{w^t} )}{\partial \bf{w^t}} + \alpha \Delta \bf{w^t}$$

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) # momentumを引数として渡すとmomentumつきSGDに

4. ネットワークを学習させる

データをfor loopによって繰り返し読み込み、入力をネットワークにフィードして最適化する。

# 学習を2エポック回す(データセットを2回繰り返す)
for epoch in range(2):
    running_loss = 0.0
    # 学習データセットを順に読み込んで行く。
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = Variable(inputs), Variable(labels)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.data[0]
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

5. ネットワークをテストデータでテストする

ネットワークに入力を渡した際の出力は10次元のベクトルとなり、これがそれぞれのラベルの予測確率になるため、この予測確率が最大となるものが予測結果となる。

correct = 0
total = 0
for data in testloader:
    images, labels = data
    outputs = net(Variable(images))
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
for data in testloader:
    images, labels = data
    outputs = net(Variable(images))
    _, predicted = torch.max(outputs.data, 1)
    c = (predicted == labels).squeeze()
    for i in range(4):
        label = labels[i]
        class_correct[label] += c[i]
        class_total[label] += 1

for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))

実際に2エポックを回すとだいたいそれぞれのクラスに対して60~70%程度の精度で予測できるようになる。

Accuracy of plane : 62 %
Accuracy of   car : 65 %
Accuracy of  bird : 37 %
Accuracy of   cat : 26 %
Accuracy of  deer : 35 %
Accuracy of   dog : 55 %
Accuracy of  frog : 65 %
Accuracy of horse : 71 %
Accuracy of  ship : 63 %
Accuracy of truck : 69 %

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2