1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Kerasでの中間層出力の取得について私なりの回答

Posted at

どうやって中間層の値を取得するか

既に同じ問題に遭遇した先人の答え
Kerasでの中間層出力の取得について

私なりの回答

何が何でもカスタムモデルを利用して中間層を取得するスタイルです。

import tensorflow as tf
import numpy as np


class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()

    # noinspection PyAttributeOutsideInit
    def build(self, input_shape):
        self.dense1 = tf.keras.layers.Dense(units=input_shape[-1])
        self.dense2 = tf.keras.layers.Dense(units=input_shape[-1])
        self.dense3 = tf.keras.layers.Dense(units=1)

        self.h2_add = tf.keras.layers.Add()

        self.built = True

    def call(self, inputs, training=None, mask=None):
        h1 = self.dense1(inputs)
        h2 = self.h2_add([self.dense2(h1), inputs])

        # h2の値を取得したい
        outputs = self.dense3(h2)
        return outputs


def main():
    # 目的のモデル
    model1 = MyModel()

    # 踏み台を用意
    inputs = tf.keras.Input(shape=[3])
    outputs = model1(inputs)
    model2 = tf.keras.Model(inputs, outputs)

    # 目的の中間層の値
    h = model2.layers[1].dense3.input

    # 実際に学習などで利用するモデル
    model3 = tf.keras.Model(inputs, [outputs, h])

    x = np.random.randn(2, 3).astype(np.float32)

    # 実際に利用 
    y, h2 = model3(x)
    print(y)
    print(h2)


if __name__ == '__main__':
    main()

実行例

tf.Tensor(
[[1.8381262 ]
 [0.71730536]], shape=(2, 1), dtype=float32)
tf.Tensor(
[[-1.2135824  -1.8959625  -0.9535742 ]
 [-0.61568236  0.62535363 -1.6006978 ]], shape=(2, 3), dtype=float32)

Function APIだけで良くない?

3段階とかいったい何の冗談かといいたくなります。
とても回りくどいと思っています。

カスタムモデルを利用するメリットはある

中間層の値を取得する部分で目的の中間層に名前でアクセスできます。

# 目的の中間層の値
h = model2.layers[1].dense3.input
#                     ↑ MyModel.dense3

model2.layers[1].dense3model1.dense3を指しています。(model1はカスタムモデルMyModelのインスタンス)
カスタムモデルを構築する際に各層に名前を付けるので、中間層を指定する際にその名前でアクセスできます。

Function APIを使う場合は名前を付けられないので、インデックスで指定するしかありません。
model2.layers[1]はまさにそのことの表れです。
インデックスの1はマジックナンバーです。

中間層を名前でアクセスできるのはメリットだと思います。
特にモデルの構造が複雑で入れ子になっている場合は全てのインデックスを適切に指定するのは人間の仕事ではないと思います。

おまけ

h2 = self.h2_add([self.dense2(h1), inputs])

の部分を

h2 = self.dense2(h1) + inputs

にすると

AttributeError: Layer dense_2 is not connected, no input to return.

というようにエラーになります。

普通にカスタムモデルを利用する分には全く問題ないのですが、今回のようにFunction APIで中間層を取得しようとするとエラーになります。

ResNetが登場した当時は、Kerasで実装する際にうっかり+演算子を使ってしまってエラーになった人は多いと思います。
Tensorflowがバージョン2になった今でも、ネットワーク構造を先に定義するということをしているので、バージョン1の頃とその点は変わっていないので、注意は必要です。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?