LoginSignup
1
4

TensorFlowで中間層の値を取得する方法

Posted at

1.この記事の内容

この記事では,TensorFlowで学習したモデルの任意の中間層の計算結果(特徴マップ)を取得する方法を紹介します.
モデルによっては取得できないものもあり,その例についても紹介します.

ソースコードは筆者のGitHubに公開しています.
  → python/tensorflow_sample/Ver2.x/12_get_features

2.TensorFlowで中間層の値を取得する方法

TensorFlowではロードしたモデルに対し,layersで各層のインスタンスの取得や操作をすることができます.

モデルの定義方法によりlayersで取得できる単位がオペレーションであったりFunctional Layerであったりします.
オペレーションに対してはoutput属性を用いて,その層の出力を取得することができます.
Functional Layerに対してはget_config()を用いてFunctional Layerの情報を取得することが可能ですが,必ずしもオペレーションレベルの構造を取得できるとは限らない為,注意が必要です.
※どのようなモデルに対しても,任意の中間層の値を取得する方法をご存知の方は,コメントでご教示いただけますと幸いです.

2-1.中間層の値を取得できる例

単純な構造のCNNで,CIFAR-10で学習した画像識別モデルです.

_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 input_1 (InputLayer)        [(None, 32, 32, 3)]       0

 conv2d (Conv2D)             (None, 30, 30, 32)        896

 max_pooling2d (MaxPooling2D  (None, 15, 15, 32)       0
 )

 conv2d_1 (Conv2D)           (None, 13, 13, 64)        18496

 max_pooling2d_1 (MaxPooling  (None, 6, 6, 64)         0
 2D)

 conv2d_2 (Conv2D)           (None, 4, 4, 64)          36928

 max_pooling2d_2 (MaxPooling  (None, 2, 2, 64)         0
 2D)

 flatten (Flatten)           (None, 256)               0

 dense (Dense)               (None, 64)                16448

 dense_1 (Dense)             (None, 10)                650

=================================================================
Total params: 73,418
Trainable params: 73,418
Non-trainable params: 0
_________________________________________________________________

このモデルでは,layersで各層のインスタンスをオペレーションレベルで取得することができます.
学習済みモデルを読み込んだ後,各層のoutputを取得し,tensorflow.keras.models.Modelのインスタンスを生成しなおします.この時にoutputs=feature_listとすることで推論時(model.predict())に各層の中間層の値が返ってくるようになります.

model.predict()の戻り値がリストで,feature_listの順序に対応しています.

    # --- Load model ---
    saved_model_path = Path('sample_model', 'saved_model')
    model = tf.keras.models.load_model(saved_model_path)
    model.summary()
    
    inputs = model.input
    feature_attr = ['Conv2D']
    feature_list = []
    feature_list.append(model.output)
    for i, layer in enumerate(model.layers):
        if (layer.__class__.__name__ in feature_attr):
            feature_list.append(layer.output)
        elif (layer.__class__.__name__ == 'Functional'):
            layer_config = layer.get_config()
            for j, func_layer in enumerate(layer_config['layers']):
                if (func_layer['class_name'] in feature_attr):
                    feature_list.append(layer.layers[j].output)
    
    # --- Build model to get feature maps ---
    model = tf.keras.models.Model(inputs=inputs, outputs=feature_list)

実際に取得した中間層を画像化した例が下図です.
この例では画像化した対象はconv2dの出力値です.

2-1.png

2-2.中間層の値を取得できない例(YOLOv3)

以降は中間層の値を取得できない例です.
本節はTensorFlow版のYOLOv3(本家Fork)で定義される例で,中間層にInputLayerが含まれるモデルです.

__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to
==================================================================================================
 input (InputLayer)             [(None, 416, 416, 3  0           []
                                )]

 yolo_darknet (Functional)      ((None, None, None,  40620640    ['input[0][0]']
                                 256),
                                 (None, None, None,
                                 512),
                                 (None, None, None,
                                 1024))

 yolo_conv_0 (Functional)       (None, 13, 13, 512)  11024384    ['yolo_darknet[0][2]']

 yolo_conv_1 (Functional)       (None, 26, 26, 256)  2957312     ['yolo_conv_0[0][0]',
                                                                  'yolo_darknet[0][1]']

 yolo_conv_2 (Functional)       (None, 52, 52, 128)  741376      ['yolo_conv_1[0][0]',
                                                                  'yolo_darknet[0][0]']

 yolo_output_0 (Functional)     (None, 13, 13, 3, 8  4984063     ['yolo_conv_0[0][0]']
                                5)

 yolo_output_1 (Functional)     (None, 26, 26, 3, 8  1312511     ['yolo_conv_1[0][0]']
                                5)

 yolo_output_2 (Functional)     (None, 52, 52, 3, 8  361471      ['yolo_conv_2[0][0]']
                                5)

==================================================================================================
Total params: 62,001,757
Trainable params: 61,949,149
Non-trainable params: 52,608
__________________________________________________________________________________________________

model.summary()ではFunctional Layerしか見えませんがlayer.get_config()で分解すると,InputLayerが複数層に含まれることが分かります.

layer.get_config()で分解した結果(クリックして展開)
InputLayer
Conv2D
BatchNormalization
LeakyReLU
ZeroPadding2D
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Add
ZeroPadding2D
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Add
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Add
ZeroPadding2D
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Add
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Add
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Add
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Add
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Add
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Add
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Add
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Add
ZeroPadding2D
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Add
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Add
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Add
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Add
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Add
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Add
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Add
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Add
ZeroPadding2D
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Add
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Add
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Add
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Add
InputLayer
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
InputLayer
Conv2D
BatchNormalization
LeakyReLU
UpSampling2D
InputLayer
Concatenate
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
InputLayer
Conv2D
BatchNormalization
LeakyReLU
UpSampling2D
InputLayer
Concatenate
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
Conv2D
BatchNormalization
LeakyReLU
InputLayer
Conv2D
BatchNormalization
LeakyReLU
Conv2D
Lambda
InputLayer
Conv2D
BatchNormalization
LeakyReLU
Conv2D
Lambda
InputLayer
Conv2D
BatchNormalization
LeakyReLU
Conv2D
Lambda

このようなモデルに対して,中間層の値を取得するためには,途中のInputLayerに対応するinputstf.keras.models.Modelで指定する必要があるようです.

途中のInputLayerの値は中間層の計算結果ですので,事実上,推論時(model.predict())の引数に与えることは不可能であると考えられ,このようなモデルでは中間層の値を取得することができません.

エラーメッセージ

ValueError: Graph disconnected: cannot obtain value for tensor KerasTensor(type_spec=TensorSpec(shape=(None, None, None, 3), dtype=tf.float32, name='input_1'), name='input_1', description="created by layer 'input_1'") at layer "conv2d". The following previous layers were accessed without issue: []

2-3.中間層の値を取得できない例(CenterNetHourGlass104)

最後は,TensorFlow Hubから取得したモデルについてです.
TensorFlow Hubから取得したモデルはLayersModelのインスタンスとして読み込むことができないようで,モデルの分解そのものができません.

エラーメッセージ

AttributeError: '_UserObject' object has no attribute 'summary'

3.さいごに

基本的にはInputLayerは推論対象のデータのみで,極力Functional Layerを使わない実装でモデルを定義する場合に,中間層の値を取得しやすいと思われます.
かなり限定的にも思われますが,自分で実装する場合に限っては注意すればなんとかなりそうです.

tf-keras-visを使うことで,AIの振る舞いを可視化しやすいですが,このプログラムも2-2,2-3で示したようなモデルは可視化がうまくできないようです.

1
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
4