12
11

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

PyTorchでシンプルな畳み込みニューラルネットワークを作ろう

Last updated at Posted at 2018-03-26

現代のコンピュータビジョンには欠かせない畳み込みニューラルネットワーク。その起源は遡ること高度経済成長期後の日本、1980年にNHK放送技術研究所の福島邦彦氏によって書かれたネオコグニトロンに関する論文[1]に見ることができます。そして、それから18年後の1998年に、AT&TのYann LeCun氏らによって、畳み込みニューラルネットワークの原型であるLeNetが実用化されました[2]。こうした先人の功績に思いを馳せつつ、深層学習ライブラリのPyTorchを利用して画像分類のベンチマークを行ってみましょう。

##前提環境
簡便のため環境はクラウドを利用します。ここではAmazon Web ServicesのDeep Learning AMIを利用し、関連パッケージがプリインストールされた状態でインスタンス作成ができ便利です。また、OSはUbuntu、Deep Learning AMIのバージョンは5を利用します。なお、GPUインスタンスはp2.xlargeを選択しました。インスタンスを起動させてSSHでサーバに接続したら、まずPythonの環境をPyTorch用に切り替えましょう。

$ source activate pytorch_p36

##データセット
画像認識ベンチマーク用の新しいデータセットとしてFashion MNIST[3]があります。服、靴、バッグなど、サンプル数6万件の典型的な被服品の画像です。伝統的に使われていた手書き文字のMNISTデータセットの代替として使われることが多いようです。Fashion MNISTは他のよく使われるデータと共にPyTorchであらかじめ用意されているので、読み込みが楽というメリットがあります。ここでは5万件の学習データと1万件のテストデータをPyTorchから読み込みます。なお、初回はダウンロードに少し時間がかかります。

...

fashion_mnist_data = torchvision.datasets.FashionMNIST(
    './fashion-mnist',
    transform=torchvision.transforms.ToTensor(),
    download=True)

data_loader = torch.utils.data.DataLoader(
    dataset=fashion_mnist_data,
    batch_size=16,
    shuffle=True)

fashion_mnist_data_test = torchvision.datasets.FashionMNIST(
    './fashion-mnist',
    transform=torchvision.transforms.ToTensor(),
    train=False,
    download=True)

data_loader_test = torch.utils.data.DataLoader(
    dataset=fashion_mnist_data_test,
    batch_size=16,
    shuffle=True)

...

##ネットワーク定義
畳み込みニューラルネットワークの原型であるLeNetに近いネットワーク構造をここでは実装することにしましょう。LeNetは2層の畳み込み層と2層の全結合層で構成されます。畳み込みが行われた後、データはプーリング層で半分の大きさになります。活性化関数はReLUやソフトマックス関数が用いられます。また、この例の場合、損失関数はネガティブログ損失関数(Negative Log-likelihood Loss)で、これは多クラス交差エントロピー関数とも呼ばれます。出力層の活性化関数はソフトマックス関数ですが、この組み合わせにすると出力層の勾配がy-aというきれいな式になる(ただし、yは正解、aが出力)メリットがあります[4]。

x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4*4*50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)

...

criterion = nn.NLLLoss()

...

##学習
GPUが使える場合、ネットワークやデータや関数を明示的にGPUへ送ることで、高速なGPUの計算資源を利用することができます。具体的には net.cuda() などです。

...

using_cuda = torch.cuda.is_available()
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01)
criterion = nn.NLLLoss()
if using_cuda:
    net.cuda()
    criterion.cuda()

...

##評価
学習の評価には、1万サンプルのテストデータを使用して、正解のラベルと予測ラベルとを比較します。今回だと予測精度は90%から91%が得られました。 一点注意点としては、GPUを利用した場合、GPUから値をメインメモリに移動させる必要があります。

...

_, predicted = torch.max(output.data, 1)
if using_cuda:
    y_predicted = predicted.cpu().numpy()

...

##ソースコード

fmnist-lenet.py
#!/usr/bin/python

import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np


# network definition
class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# dataset
fashion_mnist_data = torchvision.datasets.FashionMNIST(
    './fashion-mnist',
    transform=torchvision.transforms.ToTensor(),
    download=True)

data_loader = torch.utils.data.DataLoader(
    dataset=fashion_mnist_data,
    batch_size=16,
    shuffle=True)

fashion_mnist_data_test = torchvision.datasets.FashionMNIST(
    './fashion-mnist',
    transform=torchvision.transforms.ToTensor(),
    train=False,
    download=True)

data_loader_test = torch.utils.data.DataLoader(
    dataset=fashion_mnist_data_test,
    batch_size=16,
    shuffle=True)

# start learning
using_cuda = torch.cuda.is_available()
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01)
criterion = nn.NLLLoss()
if using_cuda:
    net.cuda()
    criterion.cuda()

accuracies = []
epochs = 40
for i in range(epochs):
    for batch, labels in data_loader:
        if using_cuda:
            x = Variable(batch.cuda())
            y = Variable(labels.cuda())
        else:
            x = Variable(batch)
            y = Variable(labels)
        optimizer.zero_grad()
        output = net(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

    # test
    n_true = 0
    for batch, labels in data_loader_test:
        if using_cuda:
            output = net(Variable(batch.cuda()))
        else:
            output = net(Variable(batch))
        _, predicted = torch.max(output.data, 1)
        if using_cuda:
            y_predicted = predicted.cpu().numpy()
        else:
            y_predicted = predicted.numpy()
        n_true += np.sum(y_predicted == labels.numpy())
    
    total = len(fashion_mnist_data_test)
    accuracy = 100 * n_true / total
    print('epoch: {0}, accuracy: {1}'.format(i, (int)(accuracy)))
    accuracies.append(accuracy)

print(accuracies)

##まとめ
ここでは、畳み込みニューラルネットワークのチュートリアルとして、CNNの原型であるLeNetに近いネットワーク構造による深層学習を実装しました。実用されている画像認識では、VGGやResNetなど、より深く複雑なモデルが利用されます。PyTorchではそれらの有名なネットワークを学習済みパラメータと共に読み込むことができるため、一般的な画像認識にチャレンジしてみてはいかがでしょうか。

##参考情報
[1] Neocognitron: A Self-organizing Neural Network Model for a Mechanism of Pattern Recognition Unaffected by Shift in Position
http://www.cs.princeton.edu/courses/archive/spr08/cos598B/Readings/Fukushima1980.pdf

[2] Gradient Based Learning Applied to Document Recognition
http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf

[3] Fashion-MNIST README-ja
https://qiita.com/masao-classcat/items/9919ffc2946106efb0e5

[4] 多クラス交差エントロピー誤差関数とソフトマックス関数,その美しき微分
https://qiita.com/klis/items/4ad3032d02ff815e09e6

12
11
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
11

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?