どうやって中間層の値を取得するか
既に同じ問題に遭遇した先人の答え
Kerasでの中間層出力の取得について
私なりの回答
何が何でもカスタムモデルを利用して中間層を取得するスタイルです。
import tensorflow as tf
import numpy as np
class MyModel(tf.keras.Model):
def __init__(self):
super(MyModel, self).__init__()
# noinspection PyAttributeOutsideInit
def build(self, input_shape):
self.dense1 = tf.keras.layers.Dense(units=input_shape[-1])
self.dense2 = tf.keras.layers.Dense(units=input_shape[-1])
self.dense3 = tf.keras.layers.Dense(units=1)
self.h2_add = tf.keras.layers.Add()
self.built = True
def call(self, inputs, training=None, mask=None):
h1 = self.dense1(inputs)
h2 = self.h2_add([self.dense2(h1), inputs])
# h2の値を取得したい
outputs = self.dense3(h2)
return outputs
def main():
# 目的のモデル
model1 = MyModel()
# 踏み台を用意
inputs = tf.keras.Input(shape=[3])
outputs = model1(inputs)
model2 = tf.keras.Model(inputs, outputs)
# 目的の中間層の値
h = model2.layers[1].dense3.input
# 実際に学習などで利用するモデル
model3 = tf.keras.Model(inputs, [outputs, h])
x = np.random.randn(2, 3).astype(np.float32)
# 実際に利用
y, h2 = model3(x)
print(y)
print(h2)
if __name__ == '__main__':
main()
実行例
tf.Tensor(
[[1.8381262 ]
[0.71730536]], shape=(2, 1), dtype=float32)
tf.Tensor(
[[-1.2135824 -1.8959625 -0.9535742 ]
[-0.61568236 0.62535363 -1.6006978 ]], shape=(2, 3), dtype=float32)
Function APIだけで良くない?
3段階とかいったい何の冗談かといいたくなります。
とても回りくどいと思っています。
カスタムモデルを利用するメリットはある
中間層の値を取得する部分で目的の中間層に名前でアクセスできます。
# 目的の中間層の値
h = model2.layers[1].dense3.input
# ↑ MyModel.dense3
model2.layers[1].dense3
はmodel1.dense3
を指しています。(model1はカスタムモデルMyModel
のインスタンス)
カスタムモデルを構築する際に各層に名前を付けるので、中間層を指定する際にその名前でアクセスできます。
Function APIを使う場合は名前を付けられないので、インデックスで指定するしかありません。
model2.layers[1]
はまさにそのことの表れです。
インデックスの1はマジックナンバーです。
中間層を名前でアクセスできるのはメリットだと思います。
特にモデルの構造が複雑で入れ子になっている場合は全てのインデックスを適切に指定するのは人間の仕事ではないと思います。
おまけ
h2 = self.h2_add([self.dense2(h1), inputs])
の部分を
h2 = self.dense2(h1) + inputs
にすると
AttributeError: Layer dense_2 is not connected, no input to return.
というようにエラーになります。
普通にカスタムモデルを利用する分には全く問題ないのですが、今回のようにFunction APIで中間層を取得しようとするとエラーになります。
ResNetが登場した当時は、Kerasで実装する際にうっかり+演算子を使ってしまってエラーになった人は多いと思います。
Tensorflowがバージョン2になった今でも、ネットワーク構造を先に定義するということをしているので、バージョン1の頃とその点は変わっていないので、注意は必要です。