はじめに
tensorflow/models/research/slim
にあるモデルを転移学習させようとしたら半日躓いたので、備忘録として解決策等まとめ。
公式ドキュメントにも転移学習のやり方が書いてるが、クラス数が同じ・全てのパラメータをrestoreさせる等、融通が効かない。
やろうとしたこと
ImageNet等で学習済みのInception-v4(https://arxiv.org/pdf/1602.07261.pdf) を別のタスクに転移学習させようとした。
ネットで調べたところ、tensorflow/models/research/slim
以外に「学習済みパラメータ+コード」のセットが無かったので、これを利用することにした。
モデルの主要部分は tensorflow/models/blob/master/research/slim/nets/inception_v4.py
https://github.com/tensorflow/models/blob/master/research/slim/nets/inception_v4.py
だが、tensorflow/models/blob/master/research/slim/nets/inception_utils.py
https://github.com/tensorflow/models/blob/master/research/slim/nets/inception_utils.py
も使うので要注意。
自分で作った train.py
からモデルinception_v4.py
をimport
して使いたい。
再学習のタスクは2分類。また入力画像の大きさが小さいので、この点の対応も必要。
参考となるのはこのモデルを使った推論コード tensorflow/models/blob/master/research/slim/train_image_classifier.py
https://github.com/tensorflow/models/blob/master/research/slim/train_image_classifier.py
ファイルの用意とimport部分の修正
基本はinception_v4.py
とinception_util.py
をカレント・ディレクトリにコピペして使えばいい。
ただ inception_util.py
はパッケージ化されていて
from nets import inception_utils
となってるので、これを
import inception_utils
に変える。
アーキテクチャの設定を変更
257行目で出力のノードを設定してるので、
def inception_v4(inputs, num_classes=1001, is_training=True,
dropout_keep_prob=0.8,
reuse=None,
scope='InceptionV4',
create_aux_logits=True):
これを2に変更。
def inception_v4(inputs, num_classes=2, is_training=True,
dropout_keep_prob=0.8,
reuse=None,
scope='InceptionV4',
create_aux_logits=True):
またデフォルトの入力画像は229だが、今回は64なので、まず334行目を
inception_v4.default_image_size = 299
inception_v4.default_image_size = 64
に変更する。ただこの値はpreprocessのための値を取得するための変数で、モデルの中では使われてない。
また入力画像が小さいとpoolingして行った時に分解能を超えてしまうのでエラーとなる。64x64だと245行目block_reduction_b
でプーリングできない、というエラーが発生する。
よってblock_reduction_b
を使わず、その直前の出力をGlobal average poolingにつなげる。
具体的には147行目の関数の引数部分
def inception_v4_base(inputs, final_endpoint='Mixed_7d', scope=None):
を
def inception_v4_base(inputs, final_endpoint='Mixed_6d', scope=None):
などと適当なblockに変えればいい。
学習済みデータの読み込み
ここがハマったところ。
まずckpt
ファイルを読み込むのでsaver.restore(sess, RESTORED_MODEL_NAME)
を使うわけだが、いくつかポイントがある。
今回は最後の全結合はパラメータ数が異なるので、それを除いて読み込みたい。
そこでパラメータを集めたものをvar_to_restore
として、全結合に相当する名前空間 Logits を除いたものをrestore させる。
if RESTORED_MODEL_NAME != '':
var_to_restore = []
for num, var1 in enumerate(all_vars):
_, deter, _ = var1.name.split('/', 2)
if deter != 'Logits':
var_to_restore.append(var1)
# var_to_restore = [v for v in all_vars if not v.name.startswith('Logits')]
saver = tf.train.Saver(var_to_restore)
saver.restore(sess, RESTORED_MODEL_NAME)
print("model ", RESTORED_MODEL_NAME, " is restored.")
ここで通常ならこの直前に
logits, end_points = model(x, CLASS_NUM, is_training=True,.....)
all_vars = tf.all_variables()
などと変数を集めるが、これだとslimの場合エラーとなる。
そこで以下のサイト
http://louis-needless.hatenablog.com/entry/notfounderror-bully-me
やコード内の説明を参考にして
arg_scope = inception_v4.inception_v4_arg_scope()
with slim.arg_scope(arg_scope):
logits, end_points = model(x, CLASS_NUM, is_training=True, ...)
all_vars = tf.all_variables()
としたら上手くいった。inception_v4_arg_scope
はinception_v4.py
の最後に定義されている。
更に注意点としては、
all_vars = tf.all_variables()
....
train = tf.train.AdamOptimizer(learning_rate=0.0005, beta1=0.5).minimize(loss)
などとtraining文より先に変数を集めること。SGDなら問題ないが、AdamとかMomentumは内部に変数を持ってるので、このあとにtf.all_variables()
すると、それら変数もかき集められてしまう。
結果、保存ckptファイルに該当するものがない・・・といったエラーが出る。
ポイントまとめ
- arg_scope()を使う
- 変数をかき集める前にtrainingの1文を書く