Edited at

tensorflow/models/research/slim を転移学習する時の注意点まとめ


はじめに

tensorflow/models/research/slim にあるモデルを転移学習させようとしたら半日躓いたので、備忘録として解決策等まとめ。

https://github.com/tensorflow/models/tree/master/research/slim

公式ドキュメントにも転移学習のやり方が書いてるが、クラス数が同じ・全てのパラメータをrestoreさせる等、融通が効かない。


やろうとしたこと

ImageNet等で学習済みのInception-v4(https://arxiv.org/pdf/1602.07261.pdf) を別のタスクに転移学習させようとした。

ネットで調べたところ、tensorflow/models/research/slim 以外に「学習済みパラメータ+コード」のセットが無かったので、これを利用することにした。

モデルの主要部分は tensorflow/models/blob/master/research/slim/nets/inception_v4.py

https://github.com/tensorflow/models/blob/master/research/slim/nets/inception_v4.py

だが、tensorflow/models/blob/master/research/slim/nets/inception_utils.py

https://github.com/tensorflow/models/blob/master/research/slim/nets/inception_utils.py

も使うので要注意。

自分で作った train.py からモデルinception_v4.pyimportして使いたい。

再学習のタスクは2分類。また入力画像の大きさが小さいので、この点の対応も必要。

参考となるのはこのモデルを使った推論コード tensorflow/models/blob/master/research/slim/train_image_classifier.py

https://github.com/tensorflow/models/blob/master/research/slim/train_image_classifier.py


ファイルの用意とimport部分の修正

基本はinception_v4.pyinception_util.pyをカレント・ディレクトリにコピペして使えばいい。

ただ inception_util.py はパッケージ化されていて


inception_v4.py

from nets import inception_utils


となってるので、これを


inception_v4.py

import inception_utils


に変える。


アーキテクチャの設定を変更

257行目で出力のノードを設定してるので、


inception_v4.py

def inception_v4(inputs, num_classes=1001, is_training=True,

dropout_keep_prob=0.8,
reuse=None,
scope='InceptionV4',
create_aux_logits=True):

これを2に変更。


inception_v4.py

def inception_v4(inputs, num_classes=2, is_training=True,

dropout_keep_prob=0.8,
reuse=None,
scope='InceptionV4',
create_aux_logits=True):

またデフォルトの入力画像は229だが、今回は64なので、まず334行目を


inception_v4.py

inception_v4.default_image_size = 299



inception_v4.py

inception_v4.default_image_size = 64


に変更する。ただこの値はpreprocessのための値を取得するための変数で、モデルの中では使われてない。

また入力画像が小さいとpoolingして行った時に分解能を超えてしまうのでエラーとなる。64x64だと245行目block_reduction_bでプーリングできない、というエラーが発生する。

よってblock_reduction_bを使わず、その直前の出力をGlobal average poolingにつなげる。

具体的には147行目の関数の引数部分


inception_v4.py

def inception_v4_base(inputs, final_endpoint='Mixed_7d', scope=None):



inception_v4.py

def inception_v4_base(inputs, final_endpoint='Mixed_6d', scope=None):


などと適当なblockに変えればいい。


学習済みデータの読み込み

ここがハマったところ。

まずckptファイルを読み込むのでsaver.restore(sess, RESTORED_MODEL_NAME)を使うわけだが、いくつかポイントがある。

今回は最後の全結合はパラメータ数が異なるので、それを除いて読み込みたい。

そこでパラメータを集めたものをvar_to_restoreとして、全結合に相当する名前空間 Logits を除いたものをrestore させる。


inception_v4.py

if RESTORED_MODEL_NAME != '':

var_to_restore = []
for num, var1 in enumerate(all_vars):
_, deter, _ = var1.name.split('/', 2)
if deter != 'Logits':
var_to_restore.append(var1)
# var_to_restore = [v for v in all_vars if not v.name.startswith('Logits')]
saver = tf.train.Saver(var_to_restore)
saver.restore(sess, RESTORED_MODEL_NAME)
print("model ", RESTORED_MODEL_NAME, " is restored.")

ここで通常ならこの直前に


inception_v4.py

logits, end_points = model(x, CLASS_NUM, is_training=True,.....)

all_vars = tf.all_variables()

などと変数を集めるが、これだとslimの場合エラーとなる。

そこで以下のサイト

http://louis-needless.hatenablog.com/entry/notfounderror-bully-me

やコード内の説明を参考にして


inception_v4.py

arg_scope = inception_v4.inception_v4_arg_scope()

with slim.arg_scope(arg_scope):
logits, end_points = model(x, CLASS_NUM, is_training=True, ...)
all_vars = tf.all_variables()

としたら上手くいった。inception_v4_arg_scopeinception_v4.pyの最後に定義されている。

更に注意点としては、


inception_v4.py

all_vars = tf.all_variables()

....
train = tf.train.AdamOptimizer(learning_rate=0.0005, beta1=0.5).minimize(loss)

などとtraining文より先に変数を集めること。SGDなら問題ないが、AdamとかMomentumは内部に変数を持ってるので、このあとにtf.all_variables()すると、それら変数もかき集められてしまう。

結果、保存ckptファイルに該当するものがない・・・といったエラーが出る。


ポイントまとめ


  1. arg_scope()を使う

  2. 変数をかき集める前にtrainingの1文を書く