Kerasには、複数レイヤをまとめてインスタンス化するContainerが用意されている。例えば以下のように書くと、2層の全結合層をインスタンス化したshared_layerが作成される。これによってネットワーク内に共通のパラメータを持った層を組み込むことができる。
inputs = Input((3,))
x = Dense(3, activation="sigmoid", use_bias=False)(inputs)
x = Dense(3, activation="sigmoid", use_bias=False)(x)
shared_layer = Container(inputs=inputs, outputs=x, name="shared_layer")
検証すること
kerasではモデルを構成するそれぞれの層のtrainableを変更することでその層の重みを更新するかどうか決定するが、どうやらContainerで作ったインスタンスとContainerを構成する各層でこの設定が独立してるらしい。
ということで以下の2つの設定をそれぞれ変更した4通りについて層の重みが更新されるか検証する。
①Containerでインスタンス化した層のtrainable設定
②Container内部の層のtrainable設定
実行環境
Keras:2.1.6
Python:3.6.4
※keras2.2以降ではContainerの代わりにNetworkを使うらしいので注意
検証モデル
使用するライブラリは以下の通り。
from keras.layers import Input, Dense
from keras.models import Model
from keras.engine.topology import Container
from keras.optimizers import sgd
import numpy as np
3層のニューラルネットワークを構築し、shared_layer内のtrainable設定をいじっておく。
#共有レイヤ作成
inputs = Input((3,))
x = Dense(3, activation="sigmoid", use_bias=False)(inputs)
x = Dense(3, activation="sigmoid", use_bias=False)(x)
shared_layer = Container(inputs=inputs, outputs=x, name="shared_layer")
#テストモデル構築
y = shared_layer(inputs)
model = Model(inputs=inputs, outputs=y)
#shared_layerのうち片方の重みを固定
model.get_layer("shared_layer").layers[1].trainable=False
①shared_layer.trainable=True
trainableの状況確認
print("shared_layer.trainable = " + str(model.get_layer("shared_layer").trainable) + "\n")
print("--each layer.trainable in shared_layer--")
for i in model.get_layer("shared_layer").layers:
print(i.name, i.trainable)
shared_layer.trainable = True
--each layer.trainable in shared_layer--
input_1 False
dense_1 False
dense_2 True
パラメータの確認
for i in model.layers:
print(i.name, i.get_weights())
input_1 []
shared_layer [array([[-0.43590093, -0.49712682, -0.5146408 ],
[ 0.51134324, 0.07132649, 0.2040441 ],
[-0.62061524, 0.42325187, -0.5272429 ]], dtype=float32), array([[-0.09528399, 0.9781587 , 0.22182012],
[ 0.90270543, 0.23723674, -0.18667579],
[ 0.929554 , -0.57312965, 0.7156825 ]], dtype=float32)]
学習
#optimizer作成
SGD = sgd(lr = 10)
#ダミーデータ作成(バッチサイズ5)
train_x = np.random.rand(5,3)
train_y = np.random.rand(5,3)
#重み更新
model.compile(optimizer=SGD, loss="mean_squared_error")
model.fit(x=train_x, y=train_y)
#重みの確認
for i in model.layers:
print(i.name, i.get_weights())
input_1 []
shared_layer [array([[-0.43590093, -0.49712682, -0.5146408 ],
[ 0.51134324, 0.07132649, 0.2040441 ],
[-0.62061524, 0.42325187, -0.5272429 ]], dtype=float32), array([[-0.18501434, 0.9731035 , 0.05751626],
[ 0.8193617 , 0.24262278, -0.386313 ],
[ 0.8491909 , -0.57754445, 0.57252395]], dtype=float32)]
2番めのDense層のみパラメータが更新されている。。。
②shared_layer.trainable=False
設定変更、trainableの状況確認
#設定の変更
model.get_layer("shared_layer").trainable=False
print("shared_layer.trainable = " + str(model.get_layer("shared_layer").trainable) + "\n")
print("--each layer.trainable in shared_layer--")
for i in model.get_layer("shared_layer").layers:
print(i.name, i.trainable)
shared_layer.trainable = False
--each layer.trainable in shared_layer--
input_1 False
dense_1 False
dense_2 True
パラメータの確認
for i in model.layers:
print(i.name, i.get_weights())
input_1 []
shared_layer [array([[-0.43590093, -0.49712682, -0.5146408 ],
[ 0.51134324, 0.07132649, 0.2040441 ],
[-0.62061524, 0.42325187, -0.5272429 ]], dtype=float32), array([[-0.18501434, 0.9731035 , 0.05751626],
[ 0.8193617 , 0.24262278, -0.386313 ],
[ 0.8491909 , -0.57754445, 0.57252395]], dtype=float32)]
学習
#ダミーデータ作成(バッチサイズ5)
train_x = np.random.rand(5,3)
train_y = np.random.rand(5,3)
#重み更新
model.compile(optimizer=SGD, loss="mean_squared_error")
model.fit(x=train_x, y=train_y)
#重みの確認
for i in model.layers:
print(i.name, i.get_weights())
input_1 []
shared_layer [array([[-0.43590093, -0.49712682, -0.5146408 ],
[ 0.51134324, 0.07132649, 0.2040441 ],
[-0.62061524, 0.42325187, -0.5272429 ]], dtype=float32), array([[-0.18501434, 0.9731035 , 0.05751626],
[ 0.8193617 , 0.24262278, -0.386313 ],
[ 0.8491909 , -0.57754445, 0.57252395]], dtype=float32)]
どの層も重みの変化なし。。。
結論
まとめるとこんな感じ。早い話、コンテナのtrainableとコンテナ内の層のtrainableの論理積だった。
コンテナのtrainable\コンテナ内のlayer.trainable | True | False |
---|---|---|
True | ○ | × |
False | × | × |