はじめに
TensorFlowのModelClassの知見が溜まってきたので、まとめたいと思います。
以下のサイトを引用しています。
import
import tensorflow as tf
from tensorflow import keras
基本
class Linear(keras.layers.Layer):
def __init__(self, units=32, input_dim=32):
super(Linear, self).__init__()
w_init = tf.random_normal_initializer()
self.w = tf.Variable(
initial_value=w_init(shape=(input_dim, units), dtype="float32"),
trainable=True,
)
b_init = tf.zeros_initializer()
self.b = tf.Variable(
initial_value=b_init(shape=(units,), dtype="float32"), trainable=True
)
def call(self, inputs):
return tf.matmul(inputs, self.w) + self.b
def calc(self):
return self(inputs) ** 2
def getclass(self):
return self
superで自身のclassを引数にする必要があります。
x = tf.ones((2, 2))
linear_layer = Linear(4, 2) # (1)
y = linear_layer(x) # (2)
z = linear_layer.calc(x) # (3)
print(y)
(1)では、コンストラクターに代入され、
(2)では、callメソッドに代入されます。
(3)では、自身のcallメソッドを呼び出しているため、selfの中に
代入させます。
def getclass(self):
return self
ちなみに、pythonでは、return selfとした場合は、インスタンス変数を引き継いだ状態のclassを返すことができます。