LoginSignup
0
0

More than 3 years have passed since last update.

TensorFlowでModel・Layer・Metricをカスタマイズする

Last updated at Posted at 2020-08-17

はじめに

TensorFlowにおいてモデルを自分で実装するときには既に用意されたクラスを継承するのが便利だが,理解するのに結構時間がかかったので忘れないようにメモ.

確認した環境
  • Python 3.6 & TensorFlow 1.15
  • Python 3.7 & TensorFlow 2.2

モデルのカスタマイズ

公式のチュートリアルを参考にして作成.2つの畳込み層を通過した後に入力値を足し合わせ,出力のサイズが入力の半分になるようなResidual Blockを作る.

from tf.keras.layers import Conv2D, BatchNormalization, Add, Activation

class ResidualBlock(tf.keras.Model):

    def __init__(self, filters, kernel_size=2, block_name=''):

        # おまじない
        super(ResidualBlock, self).__init__()

        # 入力値を足し合わせるためにサイズを出力に合わせる
        self.conv0 = Conv2D(filters, 1, strides=2, padding='same', name=block_name+'_conv0')
        self.bn0 = BatchNormalization(name=block_name+'_bn0')

        # 畳み込み層(1層目でサイズを半分にする)
        self.conv1 = Conv2D(filters, kernel_size, strides=2, padding='same', activation='relu', name=block_name+'_conv1')
        self.bn1 = BatchNormalization(name=block_name+'_bn1')
        self.conv2 = Conv2D(filters, kernel_size, padding='same', activation='relu', name=block_name+'_conv2')
        self.bn2 = BatchNormalization(name=block_name+'_bn2')

        # 入力と出力を足し算
        self.add = Add(name=block_name+'_add')
        self.out = Activation('relu', name=block_name+'_out')

    def call(self, x):

        shortcut = self.conv0(x)
        shortcut = self.bn0(shortcut)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.bn2(x)

        x = self.add([shortcut,x])
        x = self.out(x)

        return x

必要な層は__init__で作っておき,callで計算の流れを実装する.

レイヤーのカスタマイズ

カスタマイズしたモデルやtf.keras.applicationsで用意されているモデルは,tf.keras.models.Sequentialを使って組み合わせることで新しいモデルを作ることができる.ヒートマップを作るとき等にモデル内部のレイヤーにアクセスしたいときがあるが,それが自分で調べる限りだと出来なさそう.そこで,モデルの出力をそのまま返すような層を単体で作っておく.

class IdentityLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(IdentityLayer, self).__init__()
    def call(self, x):
        return x

モデルのカスタマイズとほとんど同じ.

メトリックのカスタマイズ

公式のAPIを参考にして作成.Categoricalな入出力に対応するTruePositivesを作る.

class TP_metric(tf.keras.metrics.Metric):

    def __init__(self, name='TP', **kwargs):

        # おまじない
        super(TP_metric, self).__init__(name=name, **kwargs)
        # 内部変数の定義
        self.value = self.add_weight(name='tp', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):

        # one_hotベクトルをラベルの配列に変換
        y_true = tf.argmax(y_true, axis=-1)
        y_true = tf.cast(y_true, K.floatx())
        y_pred = tf.argmax(y_pred, axis=-1)
        y_pred = tf.cast(y_pred, K.floatx())

        # 真値と予測値がともに1ならばTPは1増える
        tmp = tf.equal(tf.add(y_true, y_pred), 2)
        tmp = tf.cast(tmp, K.floatx())

        # 内部変数の更新
        self.value.assign_add(tf.reduce_sum(tmp))

    def result(self):
        return self.value

けっこう型にはうるさい(Pythonなのに)ので,こまめにcastをするように心がける.最初はSparse Categoricalに対応するように作ろうとしたが,何故かy_trueのshapeがy_predと同じ値になっていた.代わりにラベルのデータセットの方を調整することで対処.

def one_hot(x):
    return tf.one_hot(x, カテゴリー数)

label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(label_list, tf.int32))
label_ds = label_ds.map(one_hot)

おまけ

tf 1.xとtf 2.xを両方使うようになったことで,eager_executionが少し分かった気になった.

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0