- Keras 2.2.4
- 以下の記事を要約したもの
概要
ResNetのようなBatchNormalization層を含んだ学習済みモデルを元に転移学習するときに、とあるKeras
のバグに行き当たることがある。その内容と対策をまとめた。
再現条件
学習済みモデルに含まれるBN層をフリーズ(layer.trainable = False
)させた上で転移学習を行うときに発生する。
現象
正しく学習・予測ができない。
特にエラーやWarning等が発生するわけではなく、学習がうまく収束しない、なかなか予測の精度が上がらない、といった現象が発生する。
原因
BN層では学習時、予測時に次のような処理を行う。
- 学習時
- 各アクティベーションの平均・分散(ミニバッチ移動平均・分散)で出力値を正規化し、さらに係数$\beta$、$\gamma$で1次変換する
- 予測時
- 学習時に用いた各アクティベーション平均・分散の値と、学習した$\beta$、$\gamma$の値で同じ処理を行う
詳細はググるなりして調べて欲しい。要は学習時にはミニバッチ統計量を使ってゴニョゴニョし、予測時には(ミニバッチなど存在しないので)学習時に使っていたミニバッチ統計量をそのまま用いてやはりゴニョゴニョする。
このBN層をフリーズさせて転移学習するために、期待する動作と実際の動作を以下の表にまとめた。
期待するKeras の動作 |
実際のKeras の動作 |
|
---|---|---|
学習フェーズ | 転移元のデータセットのミニバッチ統計量を使う | 転移先のデータセットのミニバッチ統計量を使う |
予測フェーズ | 転移元のデータセットのミニバッチ統計量を使う | 転移元のデータセットのミニバッチ統計量を使う |
上記の表の通り、フリーズさせているにも関わらず、学習時に転移先のデータセットのミニバッチ統計量を使ってしまうのが原因となる。
回避方法
2つある。
回避方法1:パッチを当てる
元の記事の筆者が問題修正済みブランチを公開している。こちらを使うのが一つ目の方法。
※元記事の冒頭にもある通り、プルリクエストを投げているが、既存の動作を変更してしまうことになる、という理由からマージはされていないとのこと。
回避方法2:特徴量抽出だけ先に行う
学習データである画像と教師ラベルを用いて学習するのではなく、
image --> [Trained Model] --> [Your Model] --> supervised label
<-------------> <---------->
freeze train
以下のように、まず学習データである画像から学習済みモデルを用いて特徴量抽出のみを行い、次に、抽出した特徴量と教師ラベルで学習を行う。
(i) 特徴量抽出
image --> [Trained Model] --> feature
(ii) 抽出した特徴量で学習
feature --> [Your Model] --> supervised label
これは計算量を節約する目的でよく用いられる方法であるが、それだけでなく、上述したKeras
の謎の挙動を回避することもできる!
- 補足:なぜこれで回避できるの?
-
Keras
は内部的にlearning phase
というフラグを持っている。Model#train()
等の学習用のメソッドを実行する際にはlearning phase
が1
に設定され、Model#predict()
等の予測用のメソッドの実行時にはlearning phase
が0
に設定される。 -
learning phase
の値によってBN層の挙動が制御される。1
であればミニバッチ統計量をデータセットから計算し、0
であれば学習時のミニバッチ統計量を流用する。 - そこで第1段階として
Model#prediction()
等で特徴量抽出だけを先にすることで、ミニバッチ統計量を元のデータセットから流用するように動作してくれる
-
その他
この記事は要約したに過ぎないので、やはり元の記事を読むことをオススメする。
また(なるべく記事更新でキャッチアップするつもりだが)最新のKerasで解決されているのかを確認した方が良い。