LoginSignup
19
21

More than 5 years have passed since last update.

Docker環境でPyTorch 〜画像解析〜 #02 モデル訓練&保存編

Last updated at Posted at 2018-07-25

はじめに

株式会社クリエイスCTOの志村です。
前回の続きです。
この記事に最初に行き着いた方は前回の記事を見ていただき、環境を作るところから始めてください。
Docker環境でPyTorch 〜画像解析〜 #01 環境構築編

今回は、自分がコードを作って上田先輩 にいっぱいコメントを書いてもらいました。
畳み込みってなんやねんって思うところもあると思いますが、ソースコードを分解して役割を明確にしているのでフレームワークに沿った動きの理解はしやすいのではないかと思います。

この記事でやること

前回同様CIFAR10のデータセットを利用し、サンプルコードを改変してモデルの訓練と保存をするところまで進めていきます。
ソースと関連するコメントはコードブロック内に記述します。

この記事でやらないこと

  1. Dockerfileを使って公式Imageを元にカスタムImageを作りません
  2. AWSのECS, EKSには対応いたしません
  3. コピペしやすい様にコードブロックの頭のユーザーマーク($, %, #)は記述いたしません

登場人物

画像解析班(この記事の筆者)

  • 志村(コードの作成)、上田(コードの解説)
    はたしてキャプチャー認証を突破することができるのか?!

株価予想班

前提条件

  1. 実験場を作成する目的
  2. 今回は最終にImageをビルドする訳ではない
  3. というわけで、かな〜り緩めにvolumesで実験ソースをマウントして行きます
  4. ホストでIDEをつかってガンガン修正しつつ、コンテナ上でガンガンテストできる環境を目指します

流れ

それでは、ざっくりとこんな感じで進めて行きます
1. CIFAR-10とは
2. モデル作る
3. 学習コード作る
4. テストコードを作る
5. 動かして確認する

Localに次の様なディレクトリ構造で構築して行きます

# ディレクトリ構成
~
└── git  
    ├── deep-learning
    │   └── pytorch (★current workspace)
    │       ├── src
    │       │   ├── car.jpg (テスト用の画像ファイル)
    │       │   ├── network.py (今回作るやつ)
    │       │   ├── learning.py (今回作るやつ)
    │       │   └── test.py (今回作るやつ)
    │       └── docker-compose.yml  
    └── education  
        └── pytorch-examples  
            └── ...

1. CIFAR-10とは

モデルを作るには取り扱うデータについて知る必要があります。  
公式: CIFAR-10データセットのドキュメントから抜粋して説明して行きます

こんな感じのデータセット

CIFAR-10データセットは、10クラスの60000個の32×32カラー画像と1クラスあたり6000個の画像で構成されています。
50000のトレーニング画像と10000のテスト画像があります。
データセットは5つのトレーニングバッチと1つのテストバッチに分けられ、それぞれに10000のイメージがあります。
テストバッチには、各クラスからランダムに選択された1000個の画像が含まれています。
トレーニングバッチには残りの画像がランダムに含まれていますが、一部のトレーニングバッチには、あるクラスの画像が他のクラスよりも多く含まれる場合があります。
その間に、トレーニングバッチには、各クラスの正確に5000個の画像が含まれています。
batches.metaは0-9の範囲の数値ラベルを意味のあるクラス名にマップするASCIIファイルで単なる行ごとに10のクラス名のリストです。
行iのクラス名は、数値ラベルiに対応します。

つまり、テストデータが10個クラスのうちどのクラスなのか?、を導き出すテストに向いているデータセットとなります。

# メタデータ
  batches.meta > クラス名とラベルID(0~9)のマッピング

# トレーニング画像
  data_batch_1 > 10000画像
  data_batch_2 > 10000画像
  data_batch_3 > 10000画像
  data_batch_4 > 10000画像
  data_batch_5 > 10000画像

# テスト画像
  test_batch   > 10000画像

data_batchの中身はこんな感じ

inputs(ドキュメント上ではdata)

10000画像x3072チャンネル値 numpyのuint8sの配列データ。
配列の各行には、32×32カラーイメージが格納されます。
3072チャンネル値の内、最初の1024エントリは赤のチャネル値を含み、次の1024は緑、最後の1024は青を含む。
イメージは行優先順序で格納されるため、配列の最初の32エントリはイメージの最初の行の赤チャネル値になります。  

labels

範囲0〜9の10000の数字のリスト。
インデックスiの番号は、配列データ内のi番目のイメージのラベルを示します。

airplane                                        
automobile                                      
bird                                        
cat                                     
deer                                        
dog                                     
frog                                        
horse                                       
ship                                        
truck

ではコードを作って行きます。

1. モデル作る

画像系ではCNNタイプのモデルをnetwork.pyという名前で作って行きます
3チャンネル画像を定義します。
前述したディレクトリ構成のnetwork.pyを作成します。

network.py
import torch.nn as nn
import torch.nn.functional as F

"""
畳み込みニューラルネットワークを定義する
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ニューラルネットワークセクションからニューラルネットワークをコピーしてから変更します
3チャンネル画像を(1チャンネル画像ではない)定義します。

nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
  SEE: https://pytorch.org/docs/stable/nn.html?highlight=linear#torch.nn.Conv2d
  INFO: kernel_sizeで指定した範囲の画像を見て、畳み込みを行っています。画像における畳み込み計算は、簡単にいうとフィルター処理です

nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)
  SEE: https://pytorch.org/docs/stable/nn.html?highlight=linear#torch.nn.MaxPool2d
  INFO: pool_sizeの範囲を見て、その中で最も大きな値を次の層に渡します。これにより、手書きの線が多少ずれていても、同じような特徴を取り出すことができます。

nn.Linear(in_features, out_features, bias=True)
  SEE: https://pytorch.org/docs/stable/nn.html?highlight=linear#torch.nn.Linear
  INFO: 線形分類器

nn.functional.relu(input, inplace=False)
  SEE: https://pytorch.org/docs/stable/nn.html#torch.nn.functional.relu
  INFO: 整流された線形単位関数を要素ごとに適用します。
"""
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

MODEL_PATH = "model.pth.tar"

2. 学習コード作る

  • torchvisionを使ってCIFAR10テストデータセットをロードおよび正規化する
  • 畳み込みニューラルネットワークを定義する
  • 損失関数を定義する
  • 訓練データ上のネットワークを訓練する
  • 訓練結果を保存します。次回以降保存された訓練結果を再利用してファインチューニングします。
learning.py
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import network
import torch.nn as nn
import torch.optim as optim
import os
from network import MODEL_PATH

"""
CIFAR 10のロードおよび正規化する
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
`torchvision`は、以下のような共通のデータセット用のデータローダーを持っています。
`torchvision.datasets`と `torch.utils.data.DataLoader`は
Imagenet、CIFAR10、MNISTなど、画像用のデータ変換器です。
CIFAR10データセットを使用します。
「飛行機」「自動車」「鳥」「猫」「鹿」「犬」「カエル」「馬」「船」「トラック」などのクラスがあります。
CIFAR-10の画像は、サイズが3×32×32であり、すなわち32×32画素の3チャンネルカラー画像である。

`torchvision.datasets`の出力は範囲[0、1]のPILImage画像です。
正規化された範囲のテンソル[-1,1]に変換します。
"""
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)

net = network.Net()

"""
損失関数とオプティマイザを定義する
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
分類クロスエントロピー損失とSGDを使用して最適化しましょう。

torch.optim.SGD(params, lr=<object object>, momentum=0, dampening=0, weight_decay=0, nesterov=False)
  SEE: https://pytorch.org/docs/master/optim.html#torch.optim.SGD
  INFO: 最適化アルゴリズム

torch.nn.Module.state_dict(destination=None, prefix='', keep_vars=False)
  SEE: https://pytorch.org/docs/stable/nn.html?highlight=state_dict#torch.nn.Module.state_dict
  INFO: モジュール全体の状態を含む辞書を返します。torch.saveに渡す事で学習パラメータのみの保存します。

torch.optim.Optimizer.state_dict()
  SEE: https://pytorch.org/docs/master/optim.html?highlight=state_dict#torch.optim.Optimizer.state_dict
  INFO: オプティマイザの状態を含む辞書を返します。torch.saveに渡す事で学習パラメータのみの保存します。
"""

# loss function
criterion = nn.CrossEntropyLoss()

# optimizer setting
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

def save_checkpoint(state, filename=MODEL_PATH):
    torch.save(state, filename)

"""
学習済みモデルを使ってさらにチューニング
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
前回学習したモデルがあれば、それを使いさらに学習させましょう。
"""
if os.path.exists(MODEL_PATH):
    source = torch.load(MODEL_PATH, map_location=lambda storage, loc: storage)
    net.load_state_dict(source)

"""
ネットワークを訓練する
^^^^^^^^^^^^^^^^^^^^
ここから面白くなります。データイテレータをループし、入力をネットワークに送り、最適化するだけです。
"""
for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    net.train()

    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        """
        # 途中経過を毎回保存する事もできます
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': net.state_dict(),
            'optimizer' : optimizer.state_dict()
        })
        """

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

torch.save(net.state_dict(), MODEL_PATH)
print('Finished Training')

3. テストコードを作る

今回は車の画像「car.jpg」を準備します。
学習済みモデルを複数回学習して「car」を帰って来ればゴールです。

test.py
import torch
import torchvision.transforms as transforms
from torch.autograd import Variable #自動微分用
from PIL import Image
import numpy as np
import network
from network import MODEL_PATH

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

net = network.Net()
net.cpu()
source = torch.load(MODEL_PATH, map_location=lambda storage, loc: storage)
net.load_state_dict(source)

"""
画像をモデルで解析可能な状態のテンソルに変換してネットワークに流し込みます
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.unsqueeze
  SEE: https://pytorch.org/docs/stable/torch.html#torch.unsqueeze
"""
img = Image.open('car.jpg')
img = img.resize((32, 32))
img = transform(img).float()
img = Variable(img)
img = img.unsqueeze(0)

with torch.no_grad():
    outputs = net(img)

    #print(outputs)
    _, predicted = torch.max(outputs, 1)
    print(classes[predicted])

4. 動かして確認する

まずは前回の公式Imageからコンテナを起動してアクセスするを参考にコンテナに入ります。

「公式Imageに足りないライブラリを入れる」のが超絶無駄なので1回コンテナを作成したら、コンテナからイメージを作成しておきましょう。

docker commit pytorch myPytorch:latest

docker-compose.ymlのimage名をmyPytorchへ書き換えるだけで
次回の起動では「公式Imageに足りないライブラリを入れる」が不要になります。

学習
python larning.py

テスト
python test.py

結果
bird

んっ!? carじゃないの?
あっ 学習が足りないのか
つーわけで5,6回python larning.pyしたら無事「car」になりました。 

次回

今回は既存のデータセットを利用して学習済みモデルを作成し
さらにその学習済みモデルをファインチューニングする流れを確認しましたが
次回は既存の学習済みモデルを利用した転移学習に手を出せたらなぁ〜と思っています。

19
21
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
19
21