12
15

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

【PyTorch】畳み込みニューラルネットワーク(CNN)で転移学習・ファインチューニングをする方法(VGG16を題材に添えて)

Last updated at Posted at 2022-11-17

0. はじめに

本記事では、タイトルの通り、VGG16を例にしてPyTorchで転移学習およびファインチューニングを行うためのコーディング方法を紹介します。
「どのような転移学習・ファインチューニングが正しいか?」まで踏み込んだ内容ではありませんのでご注意を!

学習や推論に関する全体的なコードまでは書きませんが、必要あれば以下の記事をご参照ください。
動作環境も基本的にこちらに準じます。

ちなみにVGG16は画像認識で活躍しているネットワークで、「ImageNet」という大規模な画像データセットを学習させたパラメータが公開されています。
他のネットワークであっても、今回の記事内容を参考にすればある程度は転移学習やファインチューニングの操作ができるようになると思います。

1. 転移学習・ファインチューニングとは?

例えば他サイト様『転移学習とファインチューニング』を参考にしますと、

転移学習
学習済みのモデルを、出力層だけ目的のタスク向けに変更し、その出力層のパラメータのみ学習する。

ファインチューニング
学習済みのモデルを、出力層だけ目的のタスク向けに変更し、出力層以外のパラメータも含めて学習する。

ということのようです。
これはつまり、ファインチューニングの部分集合が転移学習といったイメージですかね。

2. 転移学習・ファインチューニングをする方法

VGG16のロード

まず、PyTorchを使ってVGG16をロードしてみましょう。
今回は適当に変数myModelにロードしてみます。

python
from torchvision import models

myModel = models.vgg16(pretrained=True)

このように引数内でpretrained=Trueとすれば、学習済みのVGG16をロードすることができます。
デフォルトはFalseなので、引数なしで呼べばパラメータが初期化されたVGG16となります。

では、この中身はどうなっているのでしょうか?
ということで、コマンドプロンプトなどでmyModelを出力してみると、以下のように出てきます。

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

つまり、VGG16をロードしたmyModelは大きく以下3つ、

  • featuresなる畳み込み層
  • avgpoolなるプーリング層
  • classifierなる全結合層

の3セクションから成り立っているよ、ということです。

転移学習

転移学習では、出力部分のみを自分のタスク用に変更します。
例えば「正解 or 不正解で出力したい」のように2分類問題だったとしましょう。

まず、出力層以外は全てパラメータ固定なので、再学習時に学習済みパラメータが更新されないよう、先に全ての層のパラメータを固定してしまいましょう。

python
# パラメータ固定
for param in myModel.parameters():
    param.requires_grad = False

myModel.parameters()は各パラメータのジェネレータで、イテレートするとパラメータ内容を出してくれます。
そして各パラメータにはrequires_gradというプロパティがついていて、Trueだと勾配計算が必要(つまり学習によってパラメータ更新をしたい)、Falseだと勾配計算は不要(学習しない)、という意味になります。
例えばある層のパラメータを出力してみると、

tensor([ 5.3096e-02,  1.0494e-02,  2.3031e-01,  3.7629e-02,  …, 
        …,
        …, -5.5652e-02,  8.2594e-02], requires_grad=True)

といった感じですね。
このようにデフォルトではTrueなので、上記のようにしてFalseに変更しておきます。

ここで、パラメータを更新したい出力層については、この後の説明でネットワークを上書きしていく際にデフォルトでrequires_grad=Trueとなってくれますのでご安心ください。

さて、それでは次に出力層の変更です。
VGG16は、そのままでは出力が1000分類になっています。
このことは、先ほどmyModelの中身を出力したもののうち、(classifier)(6) で「out_features=1000」と書いてあることからも分かります。
ちなみに、以下のようにして改めて確認することもできます。

コマンドプロンプト
>>> myModel.classifier[6]
出力 : Linear(in_features=4096, out_features=1000, bias=True)

そして、この「Linear(in_features=4096, out_features=1000, bias=True)」は、自分で実装する際にはtorch.nn.Linear(4096, 1000)として実現することができます。

ということで、2分類問題にするには出力を2つにすればよいので、

python
myModel.classifier[6] = torch.nn.Linear(4096, 2)

と更新してあげれば、パラメータ数も含めてリセットされ、2出力のネットワークになります。
あとはガシガシ学習をさせてあげてください。

ファインチューニング

例えば、今回はfeaturesなる畳み込み層の学習済みパラメータはそのままに、classifierなる全結合層全体を再学習させることにしましょう。

ここで、転移学習と異なるのは、ネットワークの形状自体は変化しない層についても再学習させるので、「再学習したい層のパラメータを初期化したいか? それとも学習済みの状態を初期状態としてそこから再学習を開始させたいか?」です。

再学習したい層のパラメータを初期化したい場合

転移学習の時と同じく、まずは一旦VGG16の学習済みパラメータをロードし、全てのパラメータを固定してしまいましょう。

python
from torchvision import models

myModel = models.vgg16(pretrained=True)

for param in myModel.parameters():
    param.requires_grad = False

さて、今回の場合はいくつか方法があるかも知れませんが、ここで紹介するのは「全結合層のネットワークを改めて定義し直してしまう」方法です。
改めて全結合層のネットワークを見てみましょう。

コマンドプロンプト
>>> myModel.classifier

<以下出力>
Sequential(
  (0): Linear(in_features=25088, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Dropout(p=0.5, inplace=False)
  (3): Linear(in_features=4096, out_features=4096, bias=True)
  (4): ReLU(inplace=True)
  (5): Dropout(p=0.5, inplace=False)
  (6): Linear(in_features=4096, out_features=1000, bias=True)
)

ということで、ここからはmyModel.classifier自体を上書きしていきましょう。
処理内容は出力層の変更以外はもちろん同じになるよう実装します。上書きすることで現在格納されているパラメータを初期化します。
ちなみにtorch.nn.Sequentialは、各層の処理をひとまとめにラッピングしているだけのもので、ここではあまり気にせず、このSequentialの中で上記の内容を1行ずつ書き下していけばOKです。

python
import torch.nn as nn

# 全結合層を上書き
myModel.classifier = nn.Sequential(
        nn.Linear(25088, 4096),
        nn.ReLU(inplace=True),
        nn.Dropout(0.5),
        nn.Linear(4096, 4096),
        nn.ReLU(inplace=True),
        nn.Dropout(0.5),
        nn.Linear(4096, 2)    
    )

こうすることでmyModel.classifierの学習済みパラメータが初期化され、また先ほど出てきたrequires_gradTrueになっているので、このまま学習を行えば勾配計算によりパラメータ更新してくれます。

学習済みの状態を初期状態として再学習させたい場合

方法としては、

VGG16をpretrained=Trueで学習済みモデルとしてロードしたあと、全結合層はデフォルトのrequires_grad=Trueのままでよく、全結合層以外はrequires_grad=Falseとする。

となります。

まず最初に、学習済みのVGG16をロードしたmyModelのパラメータ名を出力してみます。

コマンドプロンプト
>>> myModel.state_dict().keys()

<以下出力>
odict_keys(['features.0.weight', 'features.0.bias', 'features.2.weight', 'features.2.bias', 'features.5.weight', 'features.5.bias', 'features.7.weight', 'features.7.bias', 'features.10.weight', 'features.10.bias', 'features.12.weight', 'features.12.bias', 'features.14.weight', 'features.14.bias', 'features.17.weight', 'features.17.bias', 'features.19.weight', 'features.19.bias', 'features.21.weight', 'features.21.bias', 'features.24.weight', 'features.24.bias', 'features.26.weight', 'features.26.bias', 'features.28.weight', 'features.28.bias', 'classifier.0.weight', 'classifier.0.bias', 'classifier.3.weight', 'classifier.3.bias', 'classifier.6.weight', 'classifier.6.bias'])

このようにすると、features(畳み込み層)およびclassifier(全結合層)の重み・バイアスが合計32個あることがわかります。

少しややこしいですが、例えば3つ目の'features.2.weight'の中身を見たい場合は、myModel.features[2].weightとします。数字がどうやらインデックス番号に一致するようです。

コマンドプロンプト
>>> myModel.features[2].weight

<以下出力>
Parameter containing:
tensor([[[[-3.0606e-02, -9.8520e-02, -1.3260e-01],
          …,
          [ 2.6805e-02, -9.3975e-02, -4.0504e-02]]]], requires_grad=True)

ここにrequires_grad=Trueが入っていますね。(requires_grad=Falseの場合は、このようにパラメータを呼び出しても明示的には何も出力されませんが、myModel.features[2].weight.requires_gradを呼び出せばちゃんと「False」と返ってきます)

ということで、今回は畳み込み層であるfeaturesrequires_grad=Falseに変更、全結合層はデフォルトのrequires_grad=Trueのままでよいので、これは以下のように実装できます。

python
for param in myModel.features.parameters():
   param.requires_grad = False

以上のようにして、パラメータの初期化や、パラメータの更新停止を操作することができます。

参考

12
15
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
15

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?