作成したモデルをオンラインサービスで使う場合、新たに蓄積されるデータを使って既存のモデルを日々更新したいですが、毎日バッチで全部のデータを回すのは時間もお金もかかります。
画像の学習では、VGG16のような学習済みのモデルに判別したい画像を読ませてFine Tuningする方法が一般的です。
そこで今回は、普通のデータで構築したモデルを保存し、そのモデルをFine Tuningしてみました。
ここではポイントだけ紹介しますので、実際に動くサンプルコードは以下から見てください。
https://github.com/tizuo/keras/blob/master/%E8%BB%A2%E7%A7%BB%E5%AD%A6%E7%BF%92%E3%83%86%E3%82%B9%E3%83%88.ipynb
ベースとするモデルを構築する
今回はirisデータを適当に分割して、2回に分けて学習させます。
まずはベースとするモデルを定義します。ポイントとしては、継承したいレイヤーにname
パラメタを入れることだけです。
model_b = Sequential()
model_b.add(Dense(4, input_shape=(4, ), name='l1'))
model_b.add(Activation('relu'))
model_b.add(Dense(4, input_shape=(4, ), name='l2'))
model_b.add(Activation('relu'))
model_b.add(Dense(3, name='cls'))
model_b.add(Activation('softmax'))
モデルの重みを保存
ベース用のデータでfitさせた後、モデルの重みを保存します。
model_b.save_weights('my_model_weights.h5')
入れ先のモデルを準備
レイヤーのname
を対応させてください。この例では新しいデータが過度に反映されるのを防ぐのにDropout層を追加しています。
model_n = Sequential()
model_n.add(Dense(4, input_shape=(4, ), name='l1'))
model_n.add(Activation('relu'))
model_n.add(Dense(4, input_shape=(4, ), name='l2'))
model_n.add(Activation('relu'))
model_n.add(Dropout(0.5))
model_n.add(Dense(3, name='cls'))
model_n.add(Activation('softmax'))
重みのロード&学習
新しく作ったモデルに重みを読み込み、残りのモデルを学習させます。
#重みの読み込み
model_n.load_weights('my_model_weights.h5', by_name=True)
#コンパイル&実行
model_n.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model_n.fit(new_X, new_Y, epochs=50, batch_size=1, verbose=1)
メモリオーバーなどで回らない量のデータを分けて学習させたい場合にも使えそうです。