17
12

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

KerasでContainerのTrainable=Falseにしたときの挙動

Posted at

はじめに

KerasでネットワークのWeightを固定させて、別のLayerのみ学習したいということはしばしばあります。その時に何を気をつけたら良いかを調べたメモです。

Versions

  • Python 3.5.2
  • Keras 2.0.2

検証

下記のようなModelを考えます。
model_normal.png

ここのNormalContainer の部分のWeightを「更新したい」ときと、「更新したくない」ときがあるとします。

直感的には Container#trainable というPropertyにFalseを設定すれば良さそうですが、それでちゃんと意図通り動くのかやってみます。

コード

# coding: utf8

import numpy as np
from keras.engine.topology import Input, Container
from keras.engine.training import Model
from keras.layers.core import Dense
from keras.utils.vis_utils import plot_model



def all_weights(m):
    return [list(w.reshape((-1))) for w in m.get_weights()]


def random_fit(m):
    x1 = np.random.random(10).reshape((5, 2))
    y1 = np.random.random(5).reshape((5, 1))
    m.fit(x1, y1, verbose=False)

np.random.seed(100)

x = in_x = Input((2, ))

# Create 2 Containers shared same wights
x = Dense(1)(x)
x = Dense(1)(x)
fc_all = Container(in_x, x, name="NormalContainer")
fc_all_not_trainable = Container(in_x, x, name="FixedContainer")

# Create 2 Models using the Containers
x = fc_all(in_x)
x = Dense(1)(x)
model_normal = Model(in_x, x)

x = fc_all_not_trainable(in_x)
x = Dense(1)(x)
model_fixed = Model(in_x, x)

# Set one Container trainable=False
fc_all_not_trainable.trainable = False  # Case1

# Compile
model_normal.compile(optimizer="sgd", loss="mse")
model_fixed.compile(optimizer="sgd", loss="mse")

# fc_all_not_trainable.trainable = False  # Case2

# Watch which weights are updated by model.fit
print("Initial Weights")
print("Model-Normal: %s" % all_weights(model_normal))
print("Model-Fixed : %s" % all_weights(model_fixed))

random_fit(model_normal)

print("after training Model-Normal")
print("Model-Normal: %s" % all_weights(model_normal))
print("Model-Fixed : %s" % all_weights(model_fixed))

random_fit(model_fixed)

print("after training Model-Fixed")
print("Model-Normal: %s" % all_weights(model_normal))
print("Model-Fixed : %s" % all_weights(model_fixed))


# plot_model(model_normal, "model_normal.png", show_shapes=True)

fc_all, fc_all_not_trainable という2つのContainerを作成します。 後者はtrainableをFalseにしておきます。
それを使ったmodel_normal, model_fixedというModelをそれぞれ作ります。

期待される動作は、

  • model_normalfit() したときは、それぞれのContainerやその他のWeightが変化する
  • model_fixedfit() したときは、それぞれのContainerのWeightは変化せず、その他のWeightは変化する

というものです。

ContainerのWeight その他のWeight
model_normal#fit() 変化する 変化する
model_fixed#fit() 変化しない 変化する

実行結果: Case1

Initial Weights
Model-Normal: [[1.2912766, -0.53409958], [0.0], [-0.1305927], [0.0], [-0.21052945], [0.0]]
Model-Fixed : [[1.2912766, -0.53409958], [0.0], [-0.1305927], [0.0], [0.37929809], [0.0]]
after training Model-Normal
Model-Normal: [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [-0.21060525], [0.0058233831]]
Model-Fixed : [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [0.37929809], [0.0]]
after training Model-Fixed
Model-Normal: [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [-0.21060525], [0.0058233831]]
Model-Fixed : [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [0.37869808], [0.0091063408]]

期待通りです。

  • after training Model-Normal の場合は Model-Fixedの[0.37929809], [0.0] の部分以外は変化しています
  • after training Model-Fixed の場合は逆に Model-Fixedの[0.37929809], [0.0] の部分のみが変化しています

注意: trainable=Falseは compile() 前に設定しておくこと

上記コードの中で、Model#compile() の後(Case2となっている場所)で、trainable=Falseにしたらどうなるでしょうか。

実行結果: Case2

Initial Weights
Model-Normal: [[1.2912766, -0.53409958], [0.0], [-0.1305927], [0.0], [-0.21052945], [0.0]]
Model-Fixed : [[1.2912766, -0.53409958], [0.0], [-0.1305927], [0.0], [0.37929809], [0.0]]
after training Model-Normal
Model-Normal: [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [-0.21060525], [0.0058233831]]
Model-Fixed : [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [0.37929809], [0.0]]
after training Model-Fixed
Model-Normal: [[1.2910744, -0.53420025], [-0.0002913858], [-0.12900624], [0.0022280237], [-0.21060525], [0.0058233831]]
Model-Fixed : [[1.2910744, -0.53420025], [-0.0002913858], [-0.12900624], [0.0022280237], [0.37869808], [0.0091063408]]

after training Model-Normal までは同じですが、
after training Model-Fixed のときに、Container の重みも一緒に変化しています。

Model#compile() は、呼び出されると内包しているLayer全てからtrainable_weights を回収する動きをします。
従って、その時点でtrainableを設定していないと意味が無いことになってしまいます。

また、 Containerの内包する全てのLayerにtrainableを設定する必要はない というのもポイントです。ContainerModel から見ると1つのLayerです。ModelContainer#trainable_weights を呼び出しますが、Container#trainable がFalseだと何も返さない(該当箇所)ので、Containerが内容する全てのLayerのWeightが更新対象にならなくなります。これが仕様なのか、単に現段階の実装がそうなっているかはちょっと不明ですが、たぶん意図的だと思います。

さいごに

ちょっともやもやしていたのが解消されました。

17
12
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
17
12

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?