LoginSignup
12
5

More than 5 years have passed since last update.

Keras(〜2.1.2)でtrainable=FalseをセットしてもBN層は完全に固定されない

Last updated at Posted at 2017-10-21

古い情報です(2018/02/18更新)
Keras 2.1.3(2018/01/16リリース)から、trainable=Falseで固定されるようになりました
https://github.com/keras-team/keras/releases/tag/2.1.3

trainable attribute in BatchNormalization now disables the updates of the batch statistics (i.e. if trainable == False the layer will now run 100% in inference mode).

以下、2.1.3より前のKerasについての情報です


Kerasで特定の層の重みを学習させないようにする場合は

layer.trainable = False 

としますが、BatchNormalization層は移動平均の更新は続くため、完全には固定されません。

バージョンは以下の通り
Python 3.5.2
Keras 2.0.8
tensorflow-gpu 1.3.0

最も単純な構造で実験

from keras.models import Sequential
from keras.layers import InputLayer, Dense
model = Sequential([InputLayer(input_shape=(1,)), Dense(1, use_bias=False)])

入力・出力ともに値1つです。biasも使用しないので、式で表すと

出力=入力\times x

となります。(これはニューラルネットワークと呼べるのだろうか)
このモデルをベースとします。

まずmodel.trainable = Falseによってモデル全体を固定し学習を止め、
そのまま入力=理想出力=1として学習ループを回して出力値の変化の有無を見てみました。

import numpy as np
from keras.models import Model, Sequential
from keras.layers import Input, InputLayer, Dense, BatchNormalization

def test(model):
  ones = np.ones((1, 1))
  model.trainable = False
  model.compile(loss='mse', optimizer='sgd')
  model.summary()
  #
  output = model.predict(ones)
  for i in range(5):
    model.train_on_batch(ones, ones)
    print(i+1, ':', np.linalg.norm(model.predict(ones) - output))

BNなし

test(Sequential([InputLayer(input_shape=(1,)), Dense(1, use_bias=False)]))
Total params: 1
Trainable params: 0
Non-trainable params: 1
_________________________________________________________________
1 : 0.0
2 : 0.0
3 : 0.0
4 : 0.0
5 : 0.0

パラメータは1つあり、trainable=False によって Non-trainable です。
出力値は変化しませんでした。

BNあり

test(Sequential([InputLayer(input_shape=(1,)), Dense(1, use_bias=False), BatchNormalization()]))
Total params: 5
Trainable params: 0
Non-trainable params: 5
_________________________________________________________________
1 : 0.00632572
2 : 0.0126196
3 : 0.018882
4 : 0.0251131
5 : 0.0313128

パラメータは5つあり、全てNon-trainableです。
しかし、出力値の変化が確認できました。

なぜ?

これは意図された挙動で、fittrain_on_batchなどで学習を行うとき、
trainable=Falseの層は逆伝播の重みの更新は停止されるものの、順伝播時の平均・分散の更新は続くために出力値が変化する
ということのようです。
predicttest_on_batchなど学習を行わない順伝播であれば出力値の変化は起こりません。
https://github.com/fchollet/keras/issues/4762#issuecomment-299606870

モデルの出力側に重みを固定した内部モデルを接続し入力側へbackpropする場合はこの仕様を考慮する必要があります。
完全にBN層を固定するには、上のリンクではFunctional APIを使用して、training=Falseとする方法が推奨されています。

input = Input((1,))
x = Dense(1, use_bias=False)(input)
x = BatchNormalization()(x, training=False)
model = Model(input, x)
test(model)
Total params: 5
Trainable params: 0
Non-trainable params: 5
_________________________________________________________________
1 : 0.0
2 : 0.0
3 : 0.0
4 : 0.0
5 : 0.0

出力値の変化が起きていないことが確認できました。

12
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
5