19
13

More than 5 years have passed since last update.

KerasでTensorFlowの関数を使ったカスタムレイヤーを作る

Posted at

Kerasでちょっと複雑な計算をしようとすると、Kerasのバックエンドで定義されている関数だけでは物足りなくなることがあります。そういうときは豊富なTensorFlowの関数を使ってみましょう。TensorFlowの関数を使ったKerasのカスタムレイヤーは意外と簡単にできたので紹介します。

Kerasのバックエンド関数は実はTensorFlow関数のラッパー

Kerasはバックエンドにより処理系統が異なりますが、TensorFlowがバックエンドのときはTensorFlowの関数をそのまま返しています。TensorFlowがバックエンドのときのソースコードを覗いてみます。こちらにあります。

例えば絶対値を返すバックエンド関数K.abs()はこんな定義になっています。

def abs(x):
    """Element-wise absolute value.
    # Arguments
        x: Tensor or variable.
    # Returns
        A tensor.
    """
    return tf.abs(x)

「ただTensorFlowの関数を返してるだけやん!」。そうです。これ見たときは拍子抜けしました。Kerasの役割とは、同一のAPIで異なるバックエンドで処理できるように保証してあげることなんですね。

つまり、Kerasのバックエンド関数がただのラッパーだということは、このようにKerasとTensorFlowの関数をごちゃまぜに書くこともできます。

import numpy as np
import keras.backend as K
import tensorflow as tf

# Kerasで完結する書き方
x = K.variable(np.arange(20).reshape(5,4))
print(K.eval(x))

# 演算にTensorFlowの関数を使う書き方
y = tf.square(x)
print(K.eval(y))
出力
[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]
 [12. 13. 14. 15.]
 [16. 17. 18. 19.]]
[[  0.   1.   4.   9.]
 [ 16.  25.  36.  49.]
 [ 64.  81. 100. 121.]
 [144. 169. 196. 225.]
 [256. 289. 324. 361.]]

Kerasで扱っているのは中身はTensorFlowのテンソルなので「tf.square()」と2乗の関数をTensorFlowのものを使っても、K.evalできちんと評価してくれます。(いちいちsess.run()しなくていいし、この書き方見やすくありません?)

TensorFlowの関数を使ったKerasのカスタムレイヤー

MNIST+多層パーセプトロンの例です。まずは全体のコードです。

from keras.layers import Dense, Input, Flatten
from keras.models import Model
from keras.engine.topology import Layer
from keras.datasets import mnist
from keras.utils import to_categorical
import tensorflow as tf
import numpy as np

def create_train_model():
    input = Input((28, 28))
    x = Flatten()(input)
    x = Dense(64, activation="relu")(x)
    x = Dense(10, activation="softmax")(x)
    return Model(input, x)

class PredictLabel(Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, inputs):
        return tf.argmax(inputs, axis=-1)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], )

def create_predict_model(train_model):
    x = PredictLabel()(train_model.output)
    return Model(train_model.inputs, x)

def train():
    train_model = create_train_model()
    pred_model = create_predict_model(train_model)

    train_model.summary()
    pred_model.summary()

    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    X_train, X_test = (X_train / 255.0).astype(np.float32), (X_test / 255.0).astype(np.float32)
    y_train, y_test = to_categorical(y_train), to_categorical(y_test)

    train_model.compile("adam", loss="categorical_crossentropy", metrics=["acc"])
    train_model.fit(X_train, y_train, batch_size=128, epochs=3, validation_data=(X_test, y_test))

    # tf関数を使った自作レイヤー
    y_pred_tf = pred_model.predict(X_test[:10])
    # 確認用
    y_pred_np = np.argmax(train_model.predict(X_test[:10]), axis=-1)
    print(y_pred_tf)
    print()
    print(y_pred_np)

if __name__ == "__main__":
    train()

この「PredictLabelクラス」というのがTensorFlowの関数を使ったカスタムレイヤーで、やっていることはSoftmaxで計算されたラベルごとの確率から、最大のもののラベルのインデックスを返すレイヤーです。もっと平たくいえば、数字ごとの推定確率を入力とし、どの数字かを出力するレイヤーです。

通常、このような処理はpredictで確率を計算し、Numpyのargmaxで計算するのが普通です。train()関数の「確認用」と書いてある部分がそれです。

今回は、訓練用のモデルと、訓練用のモデルにPredictLabelレイヤーをつけた推定用のモデルを定義しました。summaryを見ると次のようになります。

訓練用モデル
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 28, 28)            0
_________________________________________________________________
flatten_1 (Flatten)          (None, 784)               0
_________________________________________________________________
dense_1 (Dense)              (None, 64)                50240
_________________________________________________________________
dense_2 (Dense)              (None, 10)                650
=================================================================
Total params: 50,890
Trainable params: 50,890
Non-trainable params: 0
_________________________________________________________________
推定用モデル
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 28, 28)            0
_________________________________________________________________
flatten_1 (Flatten)          (None, 784)               0
_________________________________________________________________
dense_1 (Dense)              (None, 64)                50240
_________________________________________________________________
dense_2 (Dense)              (None, 10)                650
_________________________________________________________________
predict_label_1 (PredictLabe (None,)                   0
=================================================================
Total params: 50,890
Trainable params: 50,890
Non-trainable params: 0
_________________________________________________________________

推定用のモデルは、訓練用モデルにPredictLabelをつけただけなので、2つのモデル間で係数は共有されています。PredictLabelの中身を見てみましょう。

class PredictLabel(Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, inputs):
        return tf.argmax(inputs, axis=-1)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], )

コンストラクタとbuild()関数の継承は必須です。callの関数が呼ばれたときの処理で、ここでTensorFlowの関数のtf.argmaxを使っています。入力と出力でshapeが変わるのでcompute_output_shapeで(None, 10)→(None, )となるように指定しています。Kerasのカスタムレイヤーの書き方については、詳しくはこちらをご覧ください。

結果

上がTensorFlowの関数を使った例で、下がNumpyで計算した例です。一致してるのが確認できてOKですね。

結果
[7 2 1 0 4 1 4 9 6 9]

[7 2 1 0 4 1 4 9 6 9]

まとめ

Kerasのカスタムレイヤーやバックエンドの演算に、TensorFlowの関数を導入することができました。「TensorFlowでは書きたくないが、TensorFlowの豊富な関数は使いたい」というときに便利です。これで表現の幅がグッと広がるのでぜひ使いこなしたいところです。

19
13
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
19
13