概要
バッチノーマライゼーションを使って学習した後、推論時にはどう入力を正規化すれば良いのか??
について疑問に思ったので調べてみました。
結論
BNを用いた学習時に小さくないバッチサイズを使うことができていた場合
学習時のバッチ毎の正規化に使った値の移動平均を記憶させておき、その値を使って推論時の入力も正規化しましょう。
BNを用いた学習時に小さいバッチサイズしか使えなかった場合
移動平均ではなく、正規化に使われた値の統計的な平均を計算し、その値を推論時に使いましょう。
内容
バッチノーマライゼーションは、機械学習において学習時、バッチ毎に正規化を行い、リスケールをするパラメータを学習させることでモデルをより早く学習させたりする汎用的な手法です。
【GIF】初心者のためのCNNからバッチノーマライゼーションとその仲間たちまでの解説
がとてもわかりやすくバッチノーマライゼーションについて解説しているのでぜひご覧ください。
今回は、学習をより効率化するバッチノーマライゼーションですが、推論時にはどうバッチノーマライゼーションを使えば良いのか疑問に思ったので少し調べて見ました。
学習時のBNと推論時のBN
学習はいつもミニバッチサイズの元で行われます。しかしながらテストデータや実際のデータを用いる学習時において、入力はバッチの形で入ってこないことも多いです(もちろんバッチサイズの大きさ並みの入力を一気に推論することもありますが)。そのような時にバッチノーマライゼーションを使って学習したモデルに対して、推論時の入力にはどうノーマライゼーションすれば良いのでしょうか。
How does batch normalization behave differently at training time and test time?
ここにわかりやすく解説してあったのでそれを紹介します。
BNを学習時に使わない時
もしネットワークをバッチノーマライゼーションを使わないで学習させた場合を考えます。学習にバッチを使わなかったのであればパラメータの更新に1つのデータのみを使いますし、ミニバッチを使ったのであれば、複数のランダムサンプルされたデータを用いて更新値の平均値を取ることで、パラメータを更新する際の勾配をより効率的に求めることができます。
このようにバッチノーマライゼーションを使わずに学習されたモデルに対しては、推論時にはもちろんバッチ化されたテストデータを用いる必要はありません。
BNを学習に用いた場合
一方で、もしバッチノーマライゼーションを学習に用いた場合、学習時にはバッチノーマライゼーションを用いなかった時のパラメータに加え、バッチノーマライゼーションを使うたびに2つのパラメータを追加で含めることになります。
これは、バッチ毎に正規化されたデータに対し、スケーリングを行い、新たな正規分布へとマッピングを適宜行うためでした。このバッチノーマライゼーションを用いない時に比べて生まれた2つの自由度がネットワークの学習効率を高めていました。
したがって、オペレーションは
- バッチ間のデータを平均ゼロ、分散1に正規化する
- 最適な形に正規化された分布をシフト、スケールする
という2つのオペレーションに分けることができました。
2のオペレーションに関して、学習を通してこのシフトとスケーリングのパラメータに関しては学習され、それが他の重みなどのパラメータとともに収束していきます。しかしながら1に関して、どう正規化すれば良いかはもちろんバッチがどう選ばれるかによって変化するものであり、学習されるパラメータではありません。
推論時には、出力は入力にのみ依存し、入力がバッチノーマライゼーションの場所を通るたびに、学習データによって定められたシフトとスケーリングを施されます。
したがって、1の正規化に関しても、学習時に決定されるべきで、それが推論時に2とともに施されるべきなのです。
ここで、これらのノーマライゼーションパラメタは、データ分布に関してに対して正規化をする、つまりデータ分布に関して線形変換を施すものだったことを思い出すと、スケーリング、シフトのパラメータと学習後に組み合わせることができます。
推論時に使う正規化のためのパラメータを取得するため、学習時のミニバッチ学習時の正規化をするときのパラメータが保存されている必要があります。それらを学習時に保存しておいて、のちに平均を取ることでそれを推論時の入力に施すことができます。
また、学習時に正規化につかったパラメータの移動平均を計算しておくことが推奨されており、単に平均を取るよりもアキュラシーが高まるようです。