はじめに
KerasでネットワークのWeightを固定させて、別のLayerのみ学習したいということはしばしばあります。その時に何を気をつけたら良いかを調べたメモです。
Versions
- Python 3.5.2
- Keras 2.0.2
検証
ここのNormalContainer
の部分のWeightを「更新したい」ときと、「更新したくない」ときがあるとします。
直感的には Container#trainable
というPropertyにFalseを設定すれば良さそうですが、それでちゃんと意図通り動くのかやってみます。
コード
# coding: utf8
import numpy as np
from keras.engine.topology import Input, Container
from keras.engine.training import Model
from keras.layers.core import Dense
from keras.utils.vis_utils import plot_model
def all_weights(m):
return [list(w.reshape((-1))) for w in m.get_weights()]
def random_fit(m):
x1 = np.random.random(10).reshape((5, 2))
y1 = np.random.random(5).reshape((5, 1))
m.fit(x1, y1, verbose=False)
np.random.seed(100)
x = in_x = Input((2, ))
# Create 2 Containers shared same wights
x = Dense(1)(x)
x = Dense(1)(x)
fc_all = Container(in_x, x, name="NormalContainer")
fc_all_not_trainable = Container(in_x, x, name="FixedContainer")
# Create 2 Models using the Containers
x = fc_all(in_x)
x = Dense(1)(x)
model_normal = Model(in_x, x)
x = fc_all_not_trainable(in_x)
x = Dense(1)(x)
model_fixed = Model(in_x, x)
# Set one Container trainable=False
fc_all_not_trainable.trainable = False # Case1
# Compile
model_normal.compile(optimizer="sgd", loss="mse")
model_fixed.compile(optimizer="sgd", loss="mse")
# fc_all_not_trainable.trainable = False # Case2
# Watch which weights are updated by model.fit
print("Initial Weights")
print("Model-Normal: %s" % all_weights(model_normal))
print("Model-Fixed : %s" % all_weights(model_fixed))
random_fit(model_normal)
print("after training Model-Normal")
print("Model-Normal: %s" % all_weights(model_normal))
print("Model-Fixed : %s" % all_weights(model_fixed))
random_fit(model_fixed)
print("after training Model-Fixed")
print("Model-Normal: %s" % all_weights(model_normal))
print("Model-Fixed : %s" % all_weights(model_fixed))
# plot_model(model_normal, "model_normal.png", show_shapes=True)
fc_all
, fc_all_not_trainable
という2つのContainer
を作成します。 後者はtrainable
をFalseにしておきます。
それを使ったmodel_normal
, model_fixed
というModel
をそれぞれ作ります。
期待される動作は、
-
model_normal
をfit()
したときは、それぞれのContainerやその他のWeightが変化する -
model_fixed
をfit()
したときは、それぞれのContainerのWeightは変化せず、その他のWeightは変化する
というものです。
ContainerのWeight | その他のWeight | |
---|---|---|
model_normal#fit() | 変化する | 変化する |
model_fixed#fit() | 変化しない | 変化する |
実行結果: Case1
Initial Weights
Model-Normal: [[1.2912766, -0.53409958], [0.0], [-0.1305927], [0.0], [-0.21052945], [0.0]]
Model-Fixed : [[1.2912766, -0.53409958], [0.0], [-0.1305927], [0.0], [0.37929809], [0.0]]
after training Model-Normal
Model-Normal: [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [-0.21060525], [0.0058233831]]
Model-Fixed : [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [0.37929809], [0.0]]
after training Model-Fixed
Model-Normal: [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [-0.21060525], [0.0058233831]]
Model-Fixed : [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [0.37869808], [0.0091063408]]
期待通りです。
-
after training Model-Normal
の場合は Model-Fixedの[0.37929809], [0.0]
の部分以外は変化しています -
after training Model-Fixed
の場合は逆に Model-Fixedの[0.37929809], [0.0]
の部分のみが変化しています
注意: trainable=Falseは compile() 前に設定しておくこと
上記コードの中で、Model#compile()
の後(Case2となっている場所)で、trainable=False
にしたらどうなるでしょうか。
実行結果: Case2
Initial Weights
Model-Normal: [[1.2912766, -0.53409958], [0.0], [-0.1305927], [0.0], [-0.21052945], [0.0]]
Model-Fixed : [[1.2912766, -0.53409958], [0.0], [-0.1305927], [0.0], [0.37929809], [0.0]]
after training Model-Normal
Model-Normal: [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [-0.21060525], [0.0058233831]]
Model-Fixed : [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [0.37929809], [0.0]]
after training Model-Fixed
Model-Normal: [[1.2910744, -0.53420025], [-0.0002913858], [-0.12900624], [0.0022280237], [-0.21060525], [0.0058233831]]
Model-Fixed : [[1.2910744, -0.53420025], [-0.0002913858], [-0.12900624], [0.0022280237], [0.37869808], [0.0091063408]]
after training Model-Normal
までは同じですが、
after training Model-Fixed
のときに、Container
の重みも一緒に変化しています。
Model#compile()
は、呼び出されると内包しているLayer全てからtrainable_weights
を回収する動きをします。
従って、その時点でtrainable
を設定していないと意味が無いことになってしまいます。
また、 Containerの内包する全てのLayerにtrainable
を設定する必要はない というのもポイントです。Container
は Model
から見ると1つのLayerです。Model
は Container#trainable_weights
を呼び出しますが、Container#trainable
がFalseだと何も返さない(該当箇所)ので、Container
が内容する全てのLayerのWeightが更新対象にならなくなります。これが仕様なのか、単に現段階の実装がそうなっているかはちょっと不明ですが、たぶん意図的だと思います。
さいごに
ちょっともやもやしていたのが解消されました。