3
4

More than 3 years have passed since last update.

【PyTorchチュートリアル⑨】Transfer Learning for Computer Vision Tutorial

Last updated at Posted at 2020-11-29

はじめに

前回に引き続き、PyTorch 公式チュートリアル の第9弾です。
今回は Transfer Learning for Computer Vision Tutorial を進めます。

Transfer Learning for Computer Vision Tutorial

このチュートリアルでは、転移学習を使用して畳み込みニューラルネットワークをトレーニングする方法を学びます。転移学習の詳細については、cs231n notesをご覧ください。

(このチュートリアルでは、torchvision の ResNet18を利用して転移学習します。ResNet18 は、深さが18層の畳み込みニューラルネットワークで、torchvision のResNet18 は、ImageNet で事前学習されています。)

引用(抜粋):

実際には、十分なサイズのデータ​​セットを持つことは比較的まれであるため、
畳み込みネットワーク全体を最初から(ランダムに初期化して)トレーニングする人はほとんどいません。
代わりに、非常に大きなデータセット(たとえば、1000のカテゴリを持つ120万の画像を含むImageNet)
でモデルを事前トレーニングしてから、対象のタスクをトレーニングさせるのが一般的です。

転移学習は以下の2つのパターンがあります。:

  • ファインチューニング: ランダムな初期化の代わりに、ImageNet の 1,000データセットでトレーニングされたネットワークなどの、事前にトレーニングされたネットワークで初期化します。以降のトレーニングは通常通りに行います。
  • 転移学習: 事前トレーニング済みのモデルを利用する点はファインチューニングと同じですが、最後のレイヤーを除くすべてのネットワークの重みを固定します。最後のレイヤーはランダムな重みを持つ新しいレイヤーに置き換えられ、このレイヤーのみがトレーニングされます。

(以下に 転移学習、ファインチューニングの説明があります。)
https://udemy.benesse.co.jp/data-science/deep-learning/transfer-learning.html

%matplotlib inline
# License: BSD
# Author: Sasank Chilamkurthy

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

plt.ion()   # interactive mode

Load Data

データの読み込みには、torchvisionおよびtorch.utils.dataパッケージを使用します。
このチュートリアルでは、アリとミツバチを分類するモデルを作成します。アリとミツバチのトレーニング画像と検証画像はそれぞれ約120枚と75枚あります。一から学習するには少ないデータですが、転移学習を利用するため、かなりうまくモデルを作成できるはずです。
このデータセットは、imagenetの非常に小さなサブセットです。
ここ からデータをダウンロードして、解凍します。

%%shell

# hymenoptera_data をダウンロードします。
wget https://download.pytorch.org/tutorial/hymenoptera_data.zip 
# data ディレクトリに解凍します。
mkdir ./data
unzip ./hymenoptera_data.zip -d ./data

(アリとミツバチのトレーニング画像と検証画像はそれぞれ約120枚と75枚あります。と書かれていますが、若干枚数が異なります。)

# トレーニングのためのデータの拡張と検証のための正規化のみ
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

(transforms でそれぞれ以下の処理を行います。)

  • RandomResizedCrop:画像を指定されたサイズ(224×224)にリサイズします。指定された範囲(初期値 0.08- 1.0)、縦横比(初期値 0.75 - 1.33...) で画像をcrop(端を切り落とす)しつつサイズを変更します。
  • RandomHorizontalFlip:画像を指定された確率(初期値0.5)で水平方向に反転します。(鏡に映す感覚)
  • ToTensor:画像を Tensor に変換します。
  • Normalize: Tensor を標準化します。
  • Resize:画像を指定されたサイズ(256×256)にリサイズします。
  • CenterCrop:画像を中央でトリミングします。

Visualize a few images(トレーニング画像を視覚化する)

データの変換を理解するために、いくつかのトレーニング画像を視覚化してみましょう。

def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # プロットが更新されるように少し一時停止します


# トレーニングデータのバッチを取得する
inputs, classes = next(iter(dataloaders['train']))

# バッチからグリッドを作成する
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])

ダウンロード.png

(データの変換を理解するには、データローダのシャッフルをやめ、元画像と比較すると分かりやすいです)

dataloaders_unshuffle = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=False, num_workers=4)
                         for x in ['train', 'val']}

# トレーニングデータのバッチを取得する
inputs, classes = next(iter(dataloaders_unshuffle['train']))

# バッチからグリッドを作成する
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])

ダウンロード.png

import matplotlib.image as mpimg
img = mpimg.imread('data/hymenoptera_data/train/ants/0013035.jpg')
plt.imshow(img)

0013035.png

img = mpimg.imread('data/hymenoptera_data/train/ants/1030023514_aad5c608f9.jpg')
plt.imshow(img)

1030023514_aad5c608f9.png

img = mpimg.imread('data/hymenoptera_data/train/ants/1095476100_3906d8afde.jpg')
plt.imshow(img)

1095476100_3906d8afde.png

img = mpimg.imread('data/hymenoptera_data/train/ants/1099452230_d1949d3250.jpg')
plt.imshow(img)

1099452230_d1949d3250.png

Training the model

それでは、モデルをトレーニングする処理を記述します。次を行います。

  • 学習率のスケジューリング
  • 最も良いモデルを保存する

以下のコードのパラメータ scheduler は torch.optim.lr_scheduler の learning rate scheduler オブジェクトです。
lr_scheduler は学習率を調整します。

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # 各エポックにはトレーニングフェーズと検証フェーズがあります
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # モデルをトレーニングモードに設定します
            else:
                model.eval()   # モデルを評価モードに設定します

            running_loss = 0.0
            running_corrects = 0

            # データを繰り返し処理します。
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # 勾配をゼロにします。
                optimizer.zero_grad()

                # 順伝播
                # トレーニングの場合は履歴を追跡します
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # トレーニングの場合は逆伝播と最適化を行います
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # 評価
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # model を deepcopy します
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # 最適なモデルの重みをロードします。
    model.load_state_dict(best_model_wts)
    return model

Visualizing the model predictions(モデル予測の視覚化)

予測した値と画像を表示する汎用関数を作成します。

def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                ax.set_title('predicted: {}'.format(class_names[preds[j]]))
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)

Finetuning the convnet(ファインチューニング)

ファインチューニングのモデルを作成します。
事前にトレーニングされたモデルをロードし、最終レイヤーをリセットします。

model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
# ここでは、各出力サンプルのサイズが2に設定されています。
# nn.Linear(num_ftrs、len(class_names))で汎用的にすることもできます。
model_ft.fc = nn.Linear(num_ftrs, 2)

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# 学習率を7エポックごとに0.1倍減算します
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

(torchvision のドキュメント を確認すると fc が resnet18 の最終層の線形レイヤーであることが分かります。
上記のコードは resnet18 の最下層の出力次数を num_ftrs で取得し、出力サイズが 2 になるよう線形レイヤーを変更しています。nesnet18との変更点は以下です。)

# nesnet18
(fc): Linear(in_features=512, out_features=1000, bias=True)

# model_ft
(fc): Linear(in_features=512, out_features=2, bias=True)

Train and evaluate(トレーニングと評価)

トレーニングを行います。
CPUでは約15〜25分かかります。GPUでは、1分で終わります。

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)
out
Epoch 0/24
----------
train Loss: 0.5961 Acc: 0.7131
val Loss: 0.1998 Acc: 0.9150

...

Epoch 24/24
----------
train Loss: 0.2326 Acc: 0.9139
val Loss: 0.2490 Acc: 0.9216

Training complete in 1m 30s
Best val Acc: 0.928105
visualize_model(model_ft)

result.png

(6枚の画像だけだと、いまひとつ精度が分からないですが、学習時の 「val Acc」が検証用データ153枚の正解率を表していて、「Best val Acc」が最もよい値です。正解率は92%程度です。)

ConvNet as fixed feature extractor(転移学習)

次に転移学習のモデルを作成します。
転移学習は、最終層を除くすべてのパラメータを固定する必要があります。
勾配が backward() で計算されないようにパラメーターを固定されるよう、requires_grad == Falseを設定します。
詳しくは、こちら のドキュメントをご覧ください。

model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False

# モジュールのパラメーターは、デフォルトでrequires_grad = True になっています
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

# ファインチューニングと違い、最終層のパラメーターのみが最適化されます
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# 学習率を7エポックごとに0.1倍減算します
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

Train and evaluate(トレーニングと評価)

トレーニングを行います。
CPUだとファインチューニングと比較して半分の時間で済みます。
ほとんどのパラメータで勾配を計算する必要がないためです。
ただし、順伝播を計算する必要があります。

model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=25)
out
Epoch 0/24
----------
train Loss: 0.7345 Acc: 0.6393
val Loss: 0.4873 Acc: 0.7582

...

Epoch 24/24
----------
train Loss: 0.3105 Acc: 0.8566
val Loss: 0.1785 Acc: 0.9608

Training complete in 1m 20s
Best val Acc: 0.960784
visualize_model(model_conv)

plt.ioff()
plt.show()

result.png

転移学習の「Best val Acc」を確認すると、正解率は96%程度です。ファインチューニングより精度がよくなっています。
ファインチューニングの方が全体のパラーメータを調整する分、精度がよいと思っていたので意外でした。
もしかしたら、今回のモデル「ResNet18」が ImageNet で事前学習されていて、今回利用したアリとミツバチの画像も「imagenetの非常に小さなサブセットです。」ってことなので、事前学習済みデータに含まれてるのかもしれません。

終わりに

今回のチュートリアルでは、事前学習済みのモデルを利用する「ファインチューニング」と「転移学習」を学びました。
次回は「Adversarial Example Generation」を進めてみたいと思います。

履歴

2020/11/29 初版公開
2020/12/20 次回のリンク追加

3
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
4