12
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Kerasでの中間層出力の取得について

Last updated at Posted at 2019-12-11

はじめに

Tensorflow2.0のKerasを用いてDeep Learningを試みており,中間層の出力が必要になりました.Kerasのドキュメントにしたがって中間層の出力を得ようとしたのですが,エラーが出てうまく行きませんでした.

調べてもいいやり方が出てこなかった(私の調べ方がまずかっただけ?)のですが,尤もらしいやり方でやったところうまくいったのでまとめることにしました.

  • 環境
    • tensorflow-gpu==2.0.0-rc1

Kerasでのモデル構築について

tf.keras.Modelを用いたモデル構築のやり方はドキュメントによるとfunction APIを用いる方法とclassを用いる方法の2通りあります.

function APIを用いる方法

import tensorflow as tf

inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

classを用いる方法

import tensorflow as tf

class MyModel(tf.keras.Model):

  def __init__(self):
    super(MyModel, self).__init__()
    self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
    self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

  def call(self, inputs):
    x = self.dense1(inputs)
    return self.dense2(x)

model = MyModel()

中間層出力の取得

今回はclassを用いて記述していたのですが,ドキュメントのFAQにはfunction APIの使用を前提とした説明のみがなされており,classを用いている場合には同じ方法ではうまくいきません.

function APIを用いる場合

from keras.models import Model

model = ...  # create the original model

layer_name = 'my_layer'
intermediate_layer_model = Model(inputs=model.input,
                                 outputs=model.get_layer(layer_name).output)
intermediate_output = intermediate_layer_model.predict(data)

function APIを用いてる場合には,このように中間層の値を出力するモデルを構築することで中間層出力を得ることができます.

classを用いる場合

function APIを用いる場合は上記のやりかたでできるのですが,classを用いている場合にこのコードを実行すると

AttributeError: Layer my_model is not connected, no input to return.

と怒られてしまいます.
classの方ではfunction APIと違ってinputやoutputを明示的に定めていないのが原因のようです.

そこで(考えてみれば当たり前なのですが)modelを定義したclassに中間層出力を出力するメソッドを追加します.

import tensorflow as tf

class MyModel(tf.keras.Model):

  def __init__(self):
    super(MyModel, self).__init__()
    self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
    self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

  def call(self, inputs):
    x = self.dense1(inputs)
    return self.dense2(x)

  def hidden_layer(self, input):
    return self.dense1(input)

model = MyModel()

このようにすることで

model.hidden_layer(input)

などとすることで中間層の出力を得ることができます.

最後に

中身自体は全然大したことは書いてないのですが単純に自分が困ってしまったので記事としてまとめさせていただきました.
tensorflow2.0が登場してドキュメントのバージョンとかがちゃんと更新されていなかったりするのがわかりにくかった原因なのかな...

12
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?