Help us understand the problem. What is going on with this article?

Chainerでfine-tuningを行う

More than 3 years have passed since last update.

すでにトレーニングされたモデルからWやbなどの重みを取得することで、より早く学習を進めることができる (fine-tuning)。
モデル自体が完全に同じ場合には当然使えるし、モデルの一部を修正して使用する場合にも、変更の必要がない層に関しては、あらかじめ学習済みのモデルから重みを借用するのが良い。例えば、分類先の画像種類が1000種類から10種類に切り替える場合などは、最終の全結合層のみ切り替える必要はあるが、それ以前の畳み込み層などは変更の必要性は少なく、その部分の重みを使用することができる。

計算時間短縮に結びつく背景には、以下のような点が挙げられる:
- あらかじめ大量の計算資源を使用して収束した時点の結果を使用するため、それまでの計算時間を省略可能。特に、畳み込み層の第一層はガボールフィルタに近似するという研究も多く、あえてそこを新規で学習させる必要性はあまり感じられない
- また、各層の平均出力が学習が進むにつれて徐々に変化していく「共変量シフト」が抑えられる。その結果、重み更新が、共変量シフトの追従に費やされなくなり、より精細な係数の設定が進む

モデルの構造と、初期値の設定の行い方

Chainerにおける各重みは、以下のような場所に保存されている。

# Optは、自分で作成したOptimizer
Target = Opt.target # Targetは、(y=予測結果,t=正解)の比較を答えとして持つClassifierクラス
Model = Target.predictor # Modelは、x=入力から、y=予測結果を返すLinkクラス

conv1_W = Model.conv1.W.data # .dataにより、Variableクラスからarrayを取り出す
conv1_b = Model.conv1.b.data

なので、モデルの初期化が終わったあとに上記の変数を書き換えて重みを設定すれば良さそう・・・と一瞬考えてしまうのだが、Chainerでは1st iterationの中で、入力ファイルの画像サイズ等を参考にW,bなどが作成されるため、それ以前にアクセスしようと思うとエラーを出す。

そこで、Model定義の中で指定できるinitialW, initial_biasに値を指定することで、設定を行う。

class AlexNet(chainer.Chain):
    def __init__(self, class_labels=100,net=None):
        super(AlexNet, self).__init__(
        conv1 = L.Convolution2D(None, 96, 11,stride=4, initialW=net.conv1.W if net else None ,initial_bias=net.conv1.b if net else None),
        ...
        )
        return



重みの固定

重みを読み込んだ層に関して、特に学習の初期には更新が行われないようにした方が、他の層の学習が効率的に進む。そのためには、optimizerのhookに以下のような変更を加えることで、重み更新が起こらないようにできる。

from chainer import cuda
class DelGradient(object):
    name = 'DelGradient'
    def __init__(self, delTgt):
        self.delTgt = delTgt

    def __call__(self, opt):
        for name,param in opt.target.namedparams():
            for d in self.delTgt:
                if d in name:
                    grad = param.grad
                    with cuda.get_device(grad):
                        grad = 0

# Mainの中
...                
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)
optimizer.add_hook(DelGradient(["conv2","conv3","conv4","conv5","fc6"]))
...

Caffeモデルからの重みの読み込み

Caffeモデルから重みを読み込む方法には、Chainerの関数を使用するものと、Caffeの関数を使用するものがある。それぞれの特徴は以下の通り:
- Chainerの関数による読み込み: コードは簡潔。またprototxtのモデル定義も不要。convolution2Dなどに使われる引数のうち、Caffeにしか存在しないもの ("group"など)に関しても、できるだけ対応できるようコーディングされている。しかしながら必ずしもすべてのモデルに対応しているわけではなく、失敗する場合もある。また、読み込みが非常に遅い
- Caffeの関数による読み込み: prototxtのモデル定義が必要なため、若干コードは汚くなる。一方で、読み込みは早く、どのようなモデルも原理的には読み込める。その一方で、上記group引数などによって生じるCaffeとChainerの係数の相違などは、手動で対応しなくてはいけない。

Chainerの関数による読み込み:

下記でモデルを読み込んだ後、その係数を新しいモデルのinitialize引数に指定する。netがpre-trainedなmodelとなる。

from chainer.links.caffe import CaffeFunction
net = CaffeFunction("caffe_alexnet_train_iter_85484.caffemodel")
#重みへのアクセス
net.conv1.W.data # W
net.conv1.b.data # b

Caffeの関数による読み込み

以下のようにしてCaffeモデルを読み込む。netがpre-trainedなmodelとなる。

注意点として、Classifierの第一引数は、caffeのtestに使用するdeploy.prototxtを指定しなくてはいけない。それ以外だと、エラーを吐く。

import caffe
net = caffe.Classifier('deploy.prototxt','caffe_alexnet_train_iter_85484.caffemodel', image_dims = (277, 277))
# 重みへのアクセス
net.params['conv1'][0].data # W
net.params['conv1'][1].data # b
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away