Help us understand the problem. What is going on with this article?

KerasでContainerのTrainable=Falseにしたときの挙動

More than 3 years have passed since last update.

はじめに

KerasでネットワークのWeightを固定させて、別のLayerのみ学習したいということはしばしばあります。その時に何を気をつけたら良いかを調べたメモです。

Versions

  • Python 3.5.2
  • Keras 2.0.2

検証

下記のようなModelを考えます。
model_normal.png

ここのNormalContainer の部分のWeightを「更新したい」ときと、「更新したくない」ときがあるとします。

直感的には Container#trainable というPropertyにFalseを設定すれば良さそうですが、それでちゃんと意図通り動くのかやってみます。

コード

# coding: utf8

import numpy as np
from keras.engine.topology import Input, Container
from keras.engine.training import Model
from keras.layers.core import Dense
from keras.utils.vis_utils import plot_model



def all_weights(m):
    return [list(w.reshape((-1))) for w in m.get_weights()]


def random_fit(m):
    x1 = np.random.random(10).reshape((5, 2))
    y1 = np.random.random(5).reshape((5, 1))
    m.fit(x1, y1, verbose=False)

np.random.seed(100)

x = in_x = Input((2, ))

# Create 2 Containers shared same wights
x = Dense(1)(x)
x = Dense(1)(x)
fc_all = Container(in_x, x, name="NormalContainer")
fc_all_not_trainable = Container(in_x, x, name="FixedContainer")

# Create 2 Models using the Containers
x = fc_all(in_x)
x = Dense(1)(x)
model_normal = Model(in_x, x)

x = fc_all_not_trainable(in_x)
x = Dense(1)(x)
model_fixed = Model(in_x, x)

# Set one Container trainable=False
fc_all_not_trainable.trainable = False  # Case1

# Compile
model_normal.compile(optimizer="sgd", loss="mse")
model_fixed.compile(optimizer="sgd", loss="mse")

# fc_all_not_trainable.trainable = False  # Case2

# Watch which weights are updated by model.fit
print("Initial Weights")
print("Model-Normal: %s" % all_weights(model_normal))
print("Model-Fixed : %s" % all_weights(model_fixed))

random_fit(model_normal)

print("after training Model-Normal")
print("Model-Normal: %s" % all_weights(model_normal))
print("Model-Fixed : %s" % all_weights(model_fixed))

random_fit(model_fixed)

print("after training Model-Fixed")
print("Model-Normal: %s" % all_weights(model_normal))
print("Model-Fixed : %s" % all_weights(model_fixed))


# plot_model(model_normal, "model_normal.png", show_shapes=True)

fc_all, fc_all_not_trainable という2つのContainerを作成します。 後者はtrainableをFalseにしておきます。
それを使ったmodel_normal, model_fixedというModelをそれぞれ作ります。

期待される動作は、

  • model_normalfit() したときは、それぞれのContainerやその他のWeightが変化する
  • model_fixedfit() したときは、それぞれのContainerのWeightは変化せず、その他のWeightは変化する

というものです。

ContainerのWeight その他のWeight
model_normal#fit() 変化する 変化する
model_fixed#fit() 変化しない 変化する

実行結果: Case1

Initial Weights
Model-Normal: [[1.2912766, -0.53409958], [0.0], [-0.1305927], [0.0], [-0.21052945], [0.0]]
Model-Fixed : [[1.2912766, -0.53409958], [0.0], [-0.1305927], [0.0], [0.37929809], [0.0]]
after training Model-Normal
Model-Normal: [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [-0.21060525], [0.0058233831]]
Model-Fixed : [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [0.37929809], [0.0]]
after training Model-Fixed
Model-Normal: [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [-0.21060525], [0.0058233831]]
Model-Fixed : [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [0.37869808], [0.0091063408]]

期待通りです。

  • after training Model-Normal の場合は Model-Fixedの[0.37929809], [0.0] の部分以外は変化しています
  • after training Model-Fixed の場合は逆に Model-Fixedの[0.37929809], [0.0] の部分のみが変化しています

注意: trainable=Falseは compile() 前に設定しておくこと

上記コードの中で、Model#compile() の後(Case2となっている場所)で、trainable=Falseにしたらどうなるでしょうか。

実行結果: Case2

Initial Weights
Model-Normal: [[1.2912766, -0.53409958], [0.0], [-0.1305927], [0.0], [-0.21052945], [0.0]]
Model-Fixed : [[1.2912766, -0.53409958], [0.0], [-0.1305927], [0.0], [0.37929809], [0.0]]
after training Model-Normal
Model-Normal: [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [-0.21060525], [0.0058233831]]
Model-Fixed : [[1.2913349, -0.53398848], [0.00016010582], [-0.13071491], [-0.0012259937], [0.37929809], [0.0]]
after training Model-Fixed
Model-Normal: [[1.2910744, -0.53420025], [-0.0002913858], [-0.12900624], [0.0022280237], [-0.21060525], [0.0058233831]]
Model-Fixed : [[1.2910744, -0.53420025], [-0.0002913858], [-0.12900624], [0.0022280237], [0.37869808], [0.0091063408]]

after training Model-Normal までは同じですが、
after training Model-Fixed のときに、Container の重みも一緒に変化しています。

Model#compile() は、呼び出されると内包しているLayer全てからtrainable_weights を回収する動きをします。
従って、その時点でtrainableを設定していないと意味が無いことになってしまいます。

また、 Containerの内包する全てのLayerにtrainableを設定する必要はない というのもポイントです。ContainerModel から見ると1つのLayerです。ModelContainer#trainable_weights を呼び出しますが、Container#trainable がFalseだと何も返さない(該当箇所)ので、Containerが内容する全てのLayerのWeightが更新対象にならなくなります。これが仕様なのか、単に現段階の実装がそうなっているかはちょっと不明ですが、たぶん意図的だと思います。

さいごに

ちょっともやもやしていたのが解消されました。

mokemokechicken
お気楽極楽会社員です。気ままに投稿しています。
sprocket
"Sprocket(スプロケット)は、Webサイトにおけるコンバージョン(購入・入会・資料請求・問合せ等)を増やしたい企業様向けに、自社開発のWeb接客ツールの導入及びコンバージョン改善コンサルティングを行っている会社です。 "
https://www.sprocket.bz/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした