古い情報です(2018/02/18更新)
Keras 2.1.3(2018/01/16リリース)から、trainable=Falseで固定されるようになりました
https://github.com/keras-team/keras/releases/tag/2.1.3
trainable attribute in BatchNormalization now disables the updates of the batch statistics (i.e. if trainable == False the layer will now run 100% in inference mode).
以下、2.1.3より前のKerasについての情報です
Kerasで特定の層の重みを学習させないようにする場合は
layer.trainable = False
としますが、BatchNormalization層は移動平均の更新は続くため、完全には固定されません。
バージョンは以下の通り
Python 3.5.2
Keras 2.0.8
tensorflow-gpu 1.3.0
#最も単純な構造で実験
from keras.models import Sequential
from keras.layers import InputLayer, Dense
model = Sequential([InputLayer(input_shape=(1,)), Dense(1, use_bias=False)])
入力・出力ともに値1つです。biasも使用しないので、式で表すと
出力=入力\times x
となります。(これはニューラルネットワークと呼べるのだろうか)
このモデルをベースとします。
まずmodel.trainable = False
によってモデル全体を固定し学習を止め、
そのまま入力=理想出力=1として学習ループを回して出力値の変化の有無を見てみました。
import numpy as np
from keras.models import Model, Sequential
from keras.layers import Input, InputLayer, Dense, BatchNormalization
def test(model):
ones = np.ones((1, 1))
model.trainable = False
model.compile(loss='mse', optimizer='sgd')
model.summary()
#
output = model.predict(ones)
for i in range(5):
model.train_on_batch(ones, ones)
print(i+1, ':', np.linalg.norm(model.predict(ones) - output))
#BNなし
test(Sequential([InputLayer(input_shape=(1,)), Dense(1, use_bias=False)]))
Total params: 1
Trainable params: 0
Non-trainable params: 1
_________________________________________________________________
1 : 0.0
2 : 0.0
3 : 0.0
4 : 0.0
5 : 0.0
パラメータは1つあり、trainable=False
によって Non-trainable です。
出力値は変化しませんでした。
#BNあり
test(Sequential([InputLayer(input_shape=(1,)), Dense(1, use_bias=False), BatchNormalization()]))
Total params: 5
Trainable params: 0
Non-trainable params: 5
_________________________________________________________________
1 : 0.00632572
2 : 0.0126196
3 : 0.018882
4 : 0.0251131
5 : 0.0313128
パラメータは5つあり、全てNon-trainableです。
しかし、出力値の変化が確認できました。
#なぜ?
これは意図された挙動で、fit
やtrain_on_batch
などで学習を行うとき、
trainable=False
の層は逆伝播の重みの更新は停止されるものの、順伝播時の平均・分散の更新は続くために出力値が変化する
ということのようです。
predict
やtest_on_batch
など学習を行わない順伝播であれば出力値の変化は起こりません。
https://github.com/fchollet/keras/issues/4762#issuecomment-299606870
モデルの出力側に重みを固定した内部モデルを接続し入力側へbackpropする場合はこの仕様を考慮する必要があります。
完全にBN層を固定するには、上のリンクではFunctional APIを使用して、training=False
とする方法が推奨されています。
input = Input((1,))
x = Dense(1, use_bias=False)(input)
x = BatchNormalization()(x, training=False)
model = Model(input, x)
test(model)
Total params: 5
Trainable params: 0
Non-trainable params: 5
_________________________________________________________________
1 : 0.0
2 : 0.0
3 : 0.0
4 : 0.0
5 : 0.0
出力値の変化が起きていないことが確認できました。