Help us understand the problem. What is going on with this article?

Chainerでファインチューニングするときの個人的ベストプラクティス

More than 3 years have passed since last update.

メリークリスマス!!!! @tabe2314です。

この記事では、明日から使えるChainerテクニックとして、既存モデルをファインチューンして新しいモデルをつくる際の個人的なベストプラクティスを紹介します。

ファインチューニング

ニューラルネットを学習するために、別の問題、別のデータセットで学習済みのモデルのパラメータをコピーして、それを新しいニューラルネットのパラメータの初期値として使うことをファインチューニングといいます。
典型的なケースとして、一般物体認識のデータセットであるImageNetで学習したネットワークを物体検出やSemantic Segmentationといった別の問題に使用するといった例があります。

一般的にDeep Learningでは大量の学習データが必要とされていますが、あらかじめ(大量のデータで)学習したモデルを初期値として使いファインチューニングすることで、実際に解きたい問題に関するデータの量が不十分でも十分な性能を達成できる場合があります。また、学習にかかる時間を短縮する効果もあります。

CaffeモデルのChainerモデルへの変換

さて、Chainerでファインチューニングする場合にCaffeのModel Zooで公開されているモデルをベースとしたいことが多いと思います。このような場合、はじめにCaffeのモデルをChainerのモデルへと変換してしまい、そのあとで後述の手順でパラメータをコピーするのがおすすめです。この節ではその変換方法を説明します。
例として、VGG (https://gist.github.com/ksimonyan/3785162f95cd2d5fee77#file-readme-md) を使います。

1. Caffeモデルを読み込む

chainer.links.caffe.CaffeFunction でダウンロードしたCaffeモデルのChainerのモデルとして読み込みます。大きいモデルを読み込む場合にはかなり時間がかかります。

python
from chainer.links.caffe import CaffeFunction
vgg = CaffeFunction('VGG_ILSVRC_19_layers.caffemodel')

読み込み後は、Caffe側で設定されている名前でvgg.conv1_1.W.dataとすれば各レイヤーやそのパラメータにアクセスすることが出来ます。

2. Chainerのモデルとして保存

Caffeのモデル読み込みは時間がかかるため、CaffeFunctionで読み込んだモデルをChainerのモデルとして保存してしまうのがおすすめです。
Chainer1.5からはHDF5によるシリアライズがサポートされています。しかしながら、これを使うためには、仕様上、Chainerのネットワーク定義が必要になるのですがCaffeモデルを読み込んだ場合には別途それを作成するのは面倒です。
このため、cPickleを使って保存するのがよいでしょう。

python
import cPickle as pickle
pickle.dump(vgg, open('vgg.pkl', 'wb'))

新しいモデルにパラメータをコピー

上記の手順でCaffeからコンバートしたモデルや、他の学習済みChainerモデルから新しいモデルにパラメータをコピーする方法を紹介します。
ここで紹介する方法が役に立つのは、コピー元と新しいモデルのネットワーク構成が一部異なる場合です。
構成が全く同じ場合の手順はとても単純で、読み込んだ元モデルをそのまま学習させて更新されたモデルを別の名前で保存するだけでじゅうぶんです。

さて、ネットワーク構成を一部変えてファインチューニングするシチュエーションは例えば以下の様なものがあります。

  • 分類カテゴリが違う問題に適用するために最終層だけ付け替える (e.g., ImageNetで学習したモデルをシーン認識に使う)
  • 前半のConvolution層だけ流用して後段は普通に学習する
  • 全体的な構成はほとんど変わらないけど、一部の層のパラメータを変更する (「n層目のチャンネル数を倍にしよう」)

これらはすべて、以下の copy_model 関数にコピー元モデルと新しいモデルを渡してあげることで実現できます。
この関数は、元モデルの持つLink (パラメータ付の関数) の中から、コピー先の持つLinkと同じ名前かつパラメータのshapeが一致しているものを探して、それらをコピーします。
Chainが入れ子になっている場合には、再帰的にこの処理が行われます。

python
def copy_model(src, dst):
    assert isinstance(src, link.Chain)
    assert isinstance(dst, link.Chain)
    for child in src.children():
        if child.name not in dst.__dict__: continue
        dst_child = dst[child.name]
        if type(child) != type(dst_child): continue
        if isinstance(child, link.Chain):
            copy_model(child, dst_child)
        if isinstance(child, link.Link):
            match = True
            for a, b in zip(child.namedparams(), dst_child.namedparams()):
                if a[0] != b[0]:
                    match = False
                    break
                if a[1].data.shape != b[1].data.shape:
                    match = False
                    break
            if not match:
                print 'Ignore %s because of parameter mismatch' % child.name
                continue
            for a, b in zip(child.namedparams(), dst_child.namedparams()):
                b[1].data = a[1].data
            print 'Copy %s' % child.name

これで元モデルと違う構成の新しいモデルに、共通する部分だけを自動でコピーできるようになります。あとは新しいモデルを好きなように学習しましょう!

まとめ

  • ファインチューニングは実用上とても大事なテクニックです
  • ChainerでCaffeモデルをファインチューンしたい場合には、まず CaffeFunction で読み込んで pickle で保存しましょう
  • 紹介した copy_model 関数を使うことで、元モデルから新しいモデルに共通する部分のパラメータをコピーできます
  • あとは新しいモデルを煮るなり焼くなり!
tabe2314
http://twitter.com/tabe2314 http://github.com/t-abe
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした