Python
DeepLearning
Keras
FineTuning
転移学習

kerasで画像で主流のFine Tuning(転移学習)を、データ学習でもやってみる

More than 1 year has passed since last update.

作成したモデルをオンラインサービスで使う場合、新たに蓄積されるデータを使って既存のモデルを日々更新したいですが、毎日バッチで全部のデータを回すのは時間もお金もかかります。
画像の学習では、VGG16のような学習済みのモデルに判別したい画像を読ませてFine Tuningする方法が一般的です。
そこで今回は、普通のデータで構築したモデルを保存し、そのモデルをFine Tuningしてみました。

ここではポイントだけ紹介しますので、実際に動くサンプルコードは以下から見てください。
https://github.com/tizuo/keras/blob/master/%E8%BB%A2%E7%A7%BB%E5%AD%A6%E7%BF%92%E3%83%86%E3%82%B9%E3%83%88.ipynb

ベースとするモデルを構築する

今回はirisデータを適当に分割して、2回に分けて学習させます。
まずはベースとするモデルを定義します。ポイントとしては、継承したいレイヤーにnameパラメタを入れることだけです。

model_b = Sequential()
model_b.add(Dense(4, input_shape=(4, ), name='l1'))
model_b.add(Activation('relu'))
model_b.add(Dense(4, input_shape=(4, ), name='l2'))
model_b.add(Activation('relu'))
model_b.add(Dense(3, name='cls'))
model_b.add(Activation('softmax'))

モデルの重みを保存

ベース用のデータでfitさせた後、モデルの重みを保存します。

model_b.save_weights('my_model_weights.h5')

入れ先のモデルを準備

レイヤーのnameを対応させてください。この例では新しいデータが過度に反映されるのを防ぐのにDropout層を追加しています。

model_n = Sequential()
model_n.add(Dense(4, input_shape=(4, ), name='l1'))
model_n.add(Activation('relu'))
model_n.add(Dense(4, input_shape=(4, ), name='l2'))
model_n.add(Activation('relu'))
model_n.add(Dropout(0.5))
model_n.add(Dense(3, name='cls'))
model_n.add(Activation('softmax'))

重みのロード&学習

新しく作ったモデルに重みを読み込み、残りのモデルを学習させます。

#重みの読み込み
model_n.load_weights('my_model_weights.h5', by_name=True)

#コンパイル&実行
model_n.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model_n.fit(new_X, new_Y, epochs=50, batch_size=1, verbose=1)

メモリオーバーなどで回らない量のデータを分けて学習させたい場合にも使えそうです。