はじめに
Tensorflow2.0のKerasを用いてDeep Learningを試みており,中間層の出力が必要になりました.Kerasのドキュメントにしたがって中間層の出力を得ようとしたのですが,エラーが出てうまく行きませんでした.
調べてもいいやり方が出てこなかった(私の調べ方がまずかっただけ?)のですが,尤もらしいやり方でやったところうまくいったのでまとめることにしました.
- 環境
- tensorflow-gpu==2.0.0-rc1
Kerasでのモデル構築について
tf.keras.Modelを用いたモデル構築のやり方はドキュメントによるとfunction APIを用いる方法とclassを用いる方法の2通りあります.
function APIを用いる方法
import tensorflow as tf
inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
classを用いる方法
import tensorflow as tf
class MyModel(tf.keras.Model):
def __init__(self):
super(MyModel, self).__init__()
self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
def call(self, inputs):
x = self.dense1(inputs)
return self.dense2(x)
model = MyModel()
中間層出力の取得
今回はclassを用いて記述していたのですが,ドキュメントのFAQにはfunction APIの使用を前提とした説明のみがなされており,classを用いている場合には同じ方法ではうまくいきません.
function APIを用いる場合
from keras.models import Model
model = ... # create the original model
layer_name = 'my_layer'
intermediate_layer_model = Model(inputs=model.input,
outputs=model.get_layer(layer_name).output)
intermediate_output = intermediate_layer_model.predict(data)
function APIを用いてる場合には,このように中間層の値を出力するモデルを構築することで中間層出力を得ることができます.
classを用いる場合
function APIを用いる場合は上記のやりかたでできるのですが,classを用いている場合にこのコードを実行すると
AttributeError: Layer my_model is not connected, no input to return.
と怒られてしまいます.
classの方ではfunction APIと違ってinputやoutputを明示的に定めていないのが原因のようです.
そこで(考えてみれば当たり前なのですが)modelを定義したclassに中間層出力を出力するメソッドを追加します.
import tensorflow as tf
class MyModel(tf.keras.Model):
def __init__(self):
super(MyModel, self).__init__()
self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
def call(self, inputs):
x = self.dense1(inputs)
return self.dense2(x)
def hidden_layer(self, input):
return self.dense1(input)
model = MyModel()
このようにすることで
model.hidden_layer(input)
などとすることで中間層の出力を得ることができます.
最後に
中身自体は全然大したことは書いてないのですが単純に自分が困ってしまったので記事としてまとめさせていただきました.
tensorflow2.0が登場してドキュメントのバージョンとかがちゃんと更新されていなかったりするのがわかりにくかった原因なのかな...