LoginSignup
14
19

More than 5 years have passed since last update.

kerasで画像で主流のFine Tuning(転移学習)を、データ学習でもやってみる

Posted at

作成したモデルをオンラインサービスで使う場合、新たに蓄積されるデータを使って既存のモデルを日々更新したいですが、毎日バッチで全部のデータを回すのは時間もお金もかかります。
画像の学習では、VGG16のような学習済みのモデルに判別したい画像を読ませてFine Tuningする方法が一般的です。
そこで今回は、普通のデータで構築したモデルを保存し、そのモデルをFine Tuningしてみました。

ここではポイントだけ紹介しますので、実際に動くサンプルコードは以下から見てください。
https://github.com/tizuo/keras/blob/master/%E8%BB%A2%E7%A7%BB%E5%AD%A6%E7%BF%92%E3%83%86%E3%82%B9%E3%83%88.ipynb

ベースとするモデルを構築する

今回はirisデータを適当に分割して、2回に分けて学習させます。
まずはベースとするモデルを定義します。ポイントとしては、継承したいレイヤーにnameパラメタを入れることだけです。

model_b = Sequential()
model_b.add(Dense(4, input_shape=(4, ), name='l1'))
model_b.add(Activation('relu'))
model_b.add(Dense(4, input_shape=(4, ), name='l2'))
model_b.add(Activation('relu'))
model_b.add(Dense(3, name='cls'))
model_b.add(Activation('softmax'))

モデルの重みを保存

ベース用のデータでfitさせた後、モデルの重みを保存します。

model_b.save_weights('my_model_weights.h5')

入れ先のモデルを準備

レイヤーのnameを対応させてください。この例では新しいデータが過度に反映されるのを防ぐのにDropout層を追加しています。

model_n = Sequential()
model_n.add(Dense(4, input_shape=(4, ), name='l1'))
model_n.add(Activation('relu'))
model_n.add(Dense(4, input_shape=(4, ), name='l2'))
model_n.add(Activation('relu'))
model_n.add(Dropout(0.5))
model_n.add(Dense(3, name='cls'))
model_n.add(Activation('softmax'))

重みのロード&学習

新しく作ったモデルに重みを読み込み、残りのモデルを学習させます。

#重みの読み込み
model_n.load_weights('my_model_weights.h5', by_name=True)

#コンパイル&実行
model_n.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model_n.fit(new_X, new_Y, epochs=50, batch_size=1, verbose=1)

メモリオーバーなどで回らない量のデータを分けて学習させたい場合にも使えそうです。

14
19
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
19