5
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

お勧めのtf.kerasのカスタムレイヤーの書き方と変数名の挙動

Posted at

はじめに

tf.kerasのカスタムレイヤーでの名前の挙動についてドキュメントにない挙動を見つけたので、そのお知らせです。
ここで言っている"変数名"とはPythonの文法での変数名ではなく、Tensorflowの変数(tf.Variable)に付ける名前(引数として要求される)のことです。

お勧めの書き方の前に変数名についてちょっと説明。

変数名の具体例

下のサンプルコードのself.v1やself.v2のことではなく、my_variable1やmy_variable2のことです。

import tensorflow as tf

# カスタムレイヤーのサンプルコード
# 自作の全結合層
class MyLayer(tf.keras.layers.Layer):
    def __init__(self, output_dim):
        super().__init__()
        self.output_dim = output_dim

        # バイアス項
        # 入力データのサイズには依存していない
        self.v1 = self.add_weight(name='my_variable1', shape=[output_dim])

    def build(self, input_shape):
        # affine行列
        # 入力データのサイズに依存している
        self.v2 = self.add_weight(name='my_variable2', shape=[input_shape[1], self.output_dim])
        self.built = True

    def call(self, inputs, **kwargs):
        return tf.matmul(inputs, self.v2) + self.v1

このあたりの内容は公式のチュートリアルにある内容です。

何か問題があるのか?

とりあえず実行

実際に実行して確認してみます。

model = MyLayer(output_dim=3)
# buildメソッドは初めてデータを入力したときに実行されるので、適当なデータを入れる
x = tf.random.normal(shape=(3, 5))
y = model(x)

print(model.trainable_variables)
         ↓これが名前
[<tf.Variable 'my_variable1:0' shape=(3,) dtype=float32, numpy=array([-0.56484747,  0.00200152,  0.42238712], dtype=float32)>, 
              ↓これが名前
<tf.Variable 'my_layer/my_variable2:0' shape=(5, 3) dtype=float32, numpy=
array([[ 0.47857696, -0.04394728,  0.31904382],
       [ 0.37552172,  0.22522384,  0.07408607],
       [-0.74956644, -0.61549807, -0.41261673],
       [ 0.4850598 , -0.45188528,  0.56900233],
       [-0.39462167,  0.40858668, -0.5422235 ]], dtype=float32)>]

my_variable1:0my_layer/my_variable2:0
何か余計なものがついているけど、変数の名前はそれぞれmy_variable1とmy_variable2であると確認できたので、OK。

本当にそうでしょうか?

レイヤーを重ねた場合

さっきの例に続けて実行してみます。

# 自作のレイヤーを重ねた場合
model = tf.keras.Sequential([
    MyLayer(3),
    MyLayer(3),
    MyLayer(3)
])

[<tf.Variable 'my_variable1:0' shape=(3,) dtype=float32, (略)>,
 <tf.Variable 'sequential/my_layer_1/my_variable2:0' shape=(5, 3) dtype=float32, (略))>,
 <tf.Variable 'my_variable1:0' shape=(3,) dtype=float32, (略)>,
 <tf.Variable 'sequential/my_layer_2/my_variable2:0' shape=(3, 3) dtype=float32, (略)>,
 <tf.Variable 'my_variable1:0' shape=(3,) dtype=float32, (略)>,
 <tf.Variable 'sequential/my_layer_3/my_variable2:0' shape=(3, 3) dtype=float32, (略)]

my_variable1がいっぱいですね(泣)。
区別できません。

Tensorboardで変数のヒストグラムを描いても名前が衝突しまくりで訳がわかりませんでした。

お勧めのカスタムレイヤーの書き方

class MyLayer(tf.keras.layers.Layer):
    def __init__(self, output_dim):
        super().__init__()
        self.output_dim = output_dim
       
    def build(self, input_shape):
        # バイアス項
        # 入力データのサイズには依存していない
        self.v1 = self.add_weight(name='my_variable1', shape=[output_dim])

        # affine行列
        # 入力データのサイズに依存している
        self.v2 = self.add_weight(name='my_variable2', shape=[input_shape[1], self.output_dim])
        self.built = True

    def call(self, inputs, **kwargs):
        return tf.matmul(inputs, self.v2) + self.v1

単純に全ての変数をbuildメソッド内で宣言するだけです。

Tensorflowもバージョン2になってからは、define by runなので、モデルやレイヤーの順序を最初に実行するまで解決できないのだと思います。
そのせいで、__init__メソッドとbiuldメソッドでは大きな違いになっているのだと思います。

ちなみにtf.keras.layers.Denseなどはすべてbuildメソッド内で宣言しているので、安心して使えます。

まとめ

カスタムレイヤーで変数を宣言するときはbuildメソッド内で必ず宣言する。
__init__メソッドでは宣言しない。

余談

名前の処理の挙動の解説

末尾の:0って何?

Tensorflowの仕様で自動で追加されます。
マルチGPUなどで実行する場合は、GPUごとに変数のコピーが作られるので、それぞれに0, 1, 2, ...と順に番号が振られます。
このあたりの仕様はバージョン1の頃も同じです。

バージョン2ではtf.distribute.MirroredStrategyなどを利用してマルチGPUで上と同様のことをすると確認できます。

先頭のmy_layerは何?

my_layerはMyLayerに明示的に名前を設定しなかったときのデフォルトの名前です。
クラス名を自動でスネークケースに変換しています。

また、2個目の例でtf.keras.Sequentialを使った場合はmy_layer_1, my_layer_2, my_layer_3となっています。
これは名前の衝突を避けるために末尾に自動的に追加されます。
1個目の例でmy_layerがある状態で、2個目の例を続けて実行しているので、このようになっています。

これもバージョン1の頃と同じ挙動だと思います。
少なくともTensorflowのラッパーライブラリdm-sonnetでは同様の処理がされます。

5
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?