LoginSignup
1
5

More than 5 years have passed since last update.

tensorflow/models/research/slim を転移学習する時の注意点まとめ

Last updated at Posted at 2018-09-04

はじめに

tensorflow/models/research/slim にあるモデルを転移学習させようとしたら半日躓いたので、備忘録として解決策等まとめ。

公式ドキュメントにも転移学習のやり方が書いてるが、クラス数が同じ・全てのパラメータをrestoreさせる等、融通が効かない。

やろうとしたこと

ImageNet等で学習済みのInception-v4(https://arxiv.org/pdf/1602.07261.pdf) を別のタスクに転移学習させようとした。

ネットで調べたところ、tensorflow/models/research/slim 以外に「学習済みパラメータ+コード」のセットが無かったので、これを利用することにした。

モデルの主要部分は tensorflow/models/blob/master/research/slim/nets/inception_v4.py
https://github.com/tensorflow/models/blob/master/research/slim/nets/inception_v4.py
だが、tensorflow/models/blob/master/research/slim/nets/inception_utils.py
https://github.com/tensorflow/models/blob/master/research/slim/nets/inception_utils.py
も使うので要注意。

自分で作った train.py からモデルinception_v4.pyimportして使いたい。

再学習のタスクは2分類。また入力画像の大きさが小さいので、この点の対応も必要。

参考となるのはこのモデルを使った推論コード tensorflow/models/blob/master/research/slim/train_image_classifier.py
https://github.com/tensorflow/models/blob/master/research/slim/train_image_classifier.py

ファイルの用意とimport部分の修正

基本はinception_v4.pyinception_util.pyをカレント・ディレクトリにコピペして使えばいい。

ただ inception_util.py はパッケージ化されていて

inception_v4.py
from nets import inception_utils

となってるので、これを

inception_v4.py
import inception_utils

に変える。

アーキテクチャの設定を変更

257行目で出力のノードを設定してるので、

inception_v4.py
def inception_v4(inputs, num_classes=1001, is_training=True,
                 dropout_keep_prob=0.8,
                 reuse=None,
                 scope='InceptionV4',
                 create_aux_logits=True):

これを2に変更。

inception_v4.py
def inception_v4(inputs, num_classes=2, is_training=True,
                 dropout_keep_prob=0.8,
                 reuse=None,
                 scope='InceptionV4',
                 create_aux_logits=True):

またデフォルトの入力画像は229だが、今回は64なので、まず334行目を

inception_v4.py
inception_v4.default_image_size = 299
inception_v4.py
inception_v4.default_image_size = 64

に変更する。ただこの値はpreprocessのための値を取得するための変数で、モデルの中では使われてない。

また入力画像が小さいとpoolingして行った時に分解能を超えてしまうのでエラーとなる。64x64だと245行目block_reduction_bでプーリングできない、というエラーが発生する。

よってblock_reduction_bを使わず、その直前の出力をGlobal average poolingにつなげる。

具体的には147行目の関数の引数部分

inception_v4.py
def inception_v4_base(inputs, final_endpoint='Mixed_7d', scope=None):

inception_v4.py
def inception_v4_base(inputs, final_endpoint='Mixed_6d', scope=None):

などと適当なblockに変えればいい。

学習済みデータの読み込み

ここがハマったところ。

まずckptファイルを読み込むのでsaver.restore(sess, RESTORED_MODEL_NAME)を使うわけだが、いくつかポイントがある。

今回は最後の全結合はパラメータ数が異なるので、それを除いて読み込みたい。

そこでパラメータを集めたものをvar_to_restoreとして、全結合に相当する名前空間 Logits を除いたものをrestore させる。

inception_v4.py
if RESTORED_MODEL_NAME != '':
    var_to_restore = []
    for num, var1 in enumerate(all_vars):
        _, deter, _ = var1.name.split('/', 2)
        if deter != 'Logits':
            var_to_restore.append(var1)
    # var_to_restore = [v for v in all_vars if not v.name.startswith('Logits')]
    saver = tf.train.Saver(var_to_restore)
    saver.restore(sess, RESTORED_MODEL_NAME)
    print("model ", RESTORED_MODEL_NAME, " is restored.")

ここで通常ならこの直前に

inception_v4.py
logits, end_points = model(x, CLASS_NUM, is_training=True,.....)
all_vars = tf.all_variables()

などと変数を集めるが、これだとslimの場合エラーとなる。

そこで以下のサイト
http://louis-needless.hatenablog.com/entry/notfounderror-bully-me
やコード内の説明を参考にして

inception_v4.py
arg_scope = inception_v4.inception_v4_arg_scope()
with slim.arg_scope(arg_scope):
    logits, end_points = model(x, CLASS_NUM, is_training=True, ...)
all_vars = tf.all_variables()

としたら上手くいった。inception_v4_arg_scopeinception_v4.pyの最後に定義されている。

更に注意点としては、

inception_v4.py
all_vars = tf.all_variables()
....
train = tf.train.AdamOptimizer(learning_rate=0.0005, beta1=0.5).minimize(loss)

などとtraining文より先に変数を集めること。SGDなら問題ないが、AdamとかMomentumは内部に変数を持ってるので、このあとにtf.all_variables()すると、それら変数もかき集められてしまう。

結果、保存ckptファイルに該当するものがない・・・といったエラーが出る。

ポイントまとめ

  1. arg_scope()を使う
  2. 変数をかき集める前にtrainingの1文を書く
1
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
5