LoginSignup
15
17

More than 3 years have passed since last update.

脳死で覚えるPyTorch入門

Last updated at Posted at 2019-04-03

前書き

全てのプログラマーは写経から始まる。 by俺

この記事は機械学習入門用ではありません。

良質な写経元を提供するためにあります。

無駄のないコード無駄のない説明を用意したつもりです。

PyTorchコーディングを忘れかけた時に立ち返られる原点となれば幸いです。

実行環境

  • python (3.7.2)
  • PyTorch (1.0.1)

コード全文

PyTorch等のインストールを済ませ、僕と同じバージョンならコピペすれば動くはずです。

一つ一つ解説していきます。

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class MNISTConvNet(nn.Module):

    def __init__(self):
        super(MNISTConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = F.log_softmax(x, dim=1)
        return x

if __name__ == "__main__":

    # hyper param
    epochs = 20
    batch_size = 128

    # model
    model = MNISTConvNet()

    # data
    transform = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),
         torchvision.transforms.Normalize((0.5, ), (0.5, ))]
    )

    trainset = torchvision.datasets.MNIST(root='~/datasets', train=True, download=True, transform=transform)
    testset = torchvision.datasets.MNIST(root='~/datasets', train=False, download=True, transform=transform)

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

    # loss
    criterion = nn.CrossEntropyLoss()
    # optimizer
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # train
    for epoch in range(epochs):
        running_loss = 0.0

        for i, data in enumerate(trainloader):
            inputs, labels = data

            optimizer.zero_grad()

            outputs = model(inputs)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 100 == 99:
                print('Epoch:{}/{} loss: {:.3f}'.format(epoch + 1, epochs, running_loss / 100))
                running_loss = 0.0

    print('Finished Training')

解説

import

torchvisionってのは画像変換によく使われるパッケージ。

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

モデル定義

PyTorch公式のexampleを参考にしました。

class MNISTConvNet(nn.Module):

    def __init__(self):
        super(MNISTConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = F.log_softmax(x, dim=1)
        return x

__init__forwardを実装すればOKです。当然他にも機能を追加していってもOKです。

__init__の中に使用する層を書きます。

forwardの中にモデルを設計します。

よく見ると、relumax_pool2dなどは__init__に書かれていないことがわかります。どうやら、層は__init__に書いて、関数は__init__に書かないというスタンスがあるみたいです。

ハイパーパラメータ

ここから下は全てメインルーティングです。

# hyper param
epochs = 20
batch_size = 128

説明略

モデルインスタンスの作成

# model
model = MNISTConvNet()

説明略

(補足)

print(model)

とすると、

MNISTConvNet(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)

という風にモデルのsummaryを出力してくれます。

データ

PyTorchがデフォルトで用意してくれているデータセットを使用します。

データセットをダウンロードするpathはrootで指定できますので、お好みの場所に変更してください。

# data
transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(),
     torchvision.transforms.Normalize((0.5, ), (0.5, ))]
)

trainset = torchvision.datasets.MNIST(root='~/datasets', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='~/datasets', train=False, download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

1行目でtransformというオブジェクトを生成していると思います。Composeに正規化などを事前に指定しておきます。そして、2行目3行目の際にtransformオブジェクトを渡すことによって、正規化などが施されたデータをロードしてくれます。

4行目5行目でtrainloadertestloaderというオブジェクトを生成していると思います。このtransloaderは、あとでfor文のイテレータとしてぶち込まれるもので、(inputs, labels)というタプルをイテレートしてくれます。

loss関数と最適化アルゴリズムの設定

当然、自作関数を設定することもできます。

自作関数を設定したかったらググりましょう。(頑張れ)

# loss
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

学習

# train
for epoch in range(epochs):
    running_loss = 0.0

    for i, data in enumerate(trainloader):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = model(inputs)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:
            print('Epoch:{}/{} loss: {:.3f}'.format(epoch + 1, epochs, running_loss / 100))
            running_loss = 0.0

print('Finished Training')

trainloaderからイテレートされたdata(inputs, labels)というタプルですので、分解してあげます。

optimizer.zero_grad()でoptimizerを初期化しています。

outputs = model(inputs)で入力画像をモデルに入力し、出力値をoutputsに代入しています。

loss = criterion(outputs, labels)で損失値を計算しています。

loss.backward()で誤差逆伝播を行なっています。

optimizer.step()で最適化を行なっています。

running_loss += loss.item()から下は損失値の表示をしているだけです。

参考

自己紹介

冒頭に書くと邪魔になるので最後にひっそりと自己紹介させてください。

名前 綿岡晃輝
職業 大学院生 (2019年4月から)
分野 機械学習, 深層学習, 音声処理
Twitter @Wataoka_Koki

Twitterフォローしてね!

15
17
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
15
17