While reading Deep Learning with PyTorch, I came across the following passage (p. 223) about the mean and standard deviation that Batch Normalization uses during training versus inference.
Just as for dropout, batch normalization needs to behave differently during training and inference. In fact, at inference time, we want to avoid having the output for a specific input depend on the statistics of the other inputs we're presenting to the model. As such, we need a way to still normalize, but this time fixing the normalization parameters once and for all.

As minibatches are processed, in addition to estimating the mean and standard deviation for the current minibatch, PyTorch also updates the running estimates for mean and standard deviation that are representative of the whole dataset, as an approximation. This way, when the user specifies `model.eval()` and the model contains a batch normalization module, the running estimates are frozen and used for normalization. To unfreeze running estimates and return to using the minibatch statistics, we call `model.train()`, just as we did for dropout.
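A minimal sketch of what the passage describes, using a standalone `nn.BatchNorm1d` module (my own toy example, not from the book): in train mode it normalizes with the current minibatch's statistics and nudges its running estimates toward them; in eval mode it normalizes with the frozen running estimates instead.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(num_features=3)     # running_mean starts at 0, running_var at 1

# A minibatch with nonzero mean and non-unit std per feature.
batch = torch.randn(8, 3) * 2.0 + 5.0

bn.train()
out_train = bn(batch)                   # normalized with this batch's own mean/std
print(out_train.mean(dim=0))           # ~0 per feature
print(bn.running_mean)                  # moved toward the batch mean (momentum=0.1)

bn.eval()
out_eval = bn(batch)                    # normalized with the frozen running estimates
print(out_eval.mean(dim=0))            # generally NOT ~0: the running estimates
                                        # differ from this batch's statistics
```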
- Normalization is still needed at inference time, because the model was trained with batch normalization.
- At inference, normalization is applied to each sample on its own, so the output for one input does not depend on the other inputs in the batch.
- The mean and standard deviation used for inference-time normalization are the running estimates accumulated during training.
- When `model.eval()` is called, the running estimates of the dataset-wide mean and std are frozen and no longer updated.
- When `model.train()` is called, they can be updated again (see the sketch after this list).
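A small check of the last two bullets, again with a fresh `nn.BatchNorm1d` of my own (assumptions: default momentum, shapes chosen arbitrarily). The running estimates change only while the module is in training mode.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(3)
x = torch.randn(16, 3) + 10.0

bn.train()
before = bn.running_mean.clone()
bn(x)
print(torch.allclose(before, bn.running_mean))  # False: estimates were updated

bn.eval()
before = bn.running_mean.clone()
bn(x)
print(torch.allclose(before, bn.running_mean))  # True: estimates are frozen
```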