Keras

Keras(〜2.1.2)でtrainable=FalseをセットしてもBN層は完全に固定されない

古い情報です(2018/02/18更新)
Keras 2.1.3(2018/01/16リリース)から、trainable=Falseで固定されるようになりました
https://github.com/keras-team/keras/releases/tag/2.1.3

trainable attribute in BatchNormalization now disables the updates of the batch statistics (i.e. if trainable == False the layer will now run 100% in inference mode).

以下、2.1.3より前のKerasについての情報です


Kerasで特定の層の重みを学習させないようにする場合は

layer.trainable = False 

としますが、BatchNormalization層は移動平均の更新は続くため、完全には固定されません。

バージョンは以下の通り
Python 3.5.2
Keras 2.0.8
tensorflow-gpu 1.3.0

最も単純な構造で実験

from keras.models import Sequential
from keras.layers import InputLayer, Dense
model = Sequential([InputLayer(input_shape=(1,)), Dense(1, use_bias=False)])

入力・出力ともに値1つです。biasも使用しないので、式で表すと

出力=入力\times x

となります。(これはニューラルネットワークと呼べるのだろうか)
このモデルをベースとします。

まずmodel.trainable = Falseによってモデル全体を固定し学習を止め、
そのまま入力=理想出力=1として学習ループを回して出力値の変化の有無を見てみました。

import numpy as np
from keras.models import Model, Sequential
from keras.layers import Input, InputLayer, Dense, BatchNormalization

def test(model):
  ones = np.ones((1, 1))
  model.trainable = False
  model.compile(loss='mse', optimizer='sgd')
  model.summary()
  #
  output = model.predict(ones)
  for i in range(5):
    model.train_on_batch(ones, ones)
    print(i+1, ':', np.linalg.norm(model.predict(ones) - output))

BNなし

test(Sequential([InputLayer(input_shape=(1,)), Dense(1, use_bias=False)]))
Total params: 1
Trainable params: 0
Non-trainable params: 1
_________________________________________________________________
1 : 0.0
2 : 0.0
3 : 0.0
4 : 0.0
5 : 0.0

パラメータは1つあり、trainable=False によって Non-trainable です。
出力値は変化しませんでした。

BNあり

test(Sequential([InputLayer(input_shape=(1,)), Dense(1, use_bias=False), BatchNormalization()]))
Total params: 5
Trainable params: 0
Non-trainable params: 5
_________________________________________________________________
1 : 0.00632572
2 : 0.0126196
3 : 0.018882
4 : 0.0251131
5 : 0.0313128

パラメータは5つあり、全てNon-trainableです。
しかし、出力値の変化が確認できました。

なぜ?

これは意図された挙動で、fittrain_on_batchなどで学習を行うとき、
trainable=Falseの層は逆伝播の重みの更新は停止されるものの、順伝播時の平均・分散の更新は続くために出力値が変化する
ということのようです。
predicttest_on_batchなど学習を行わない順伝播であれば出力値の変化は起こりません。
https://github.com/fchollet/keras/issues/4762#issuecomment-299606870

モデルの出力側に重みを固定した内部モデルを接続し入力側へbackpropする場合はこの仕様を考慮する必要があります。
完全にBN層を固定するには、上のリンクではFunctional APIを使用して、training=Falseとする方法が推奨されています。

input = Input((1,))
x = Dense(1, use_bias=False)(input)
x = BatchNormalization()(x, training=False)
model = Model(input, x)
test(model)
Total params: 5
Trainable params: 0
Non-trainable params: 5
_________________________________________________________________
1 : 0.0
2 : 0.0
3 : 0.0
4 : 0.0
5 : 0.0

出力値の変化が起きていないことが確認できました。