LoginSignup
0

More than 1 year has passed since last update.

TensorFlowのModelClass入門

Posted at

はじめに

TensorFlowのModelClassの知見が溜まってきたので、まとめたいと思います。

以下のサイトを引用しています。

import

import tensorflow as tf
from tensorflow import keras

基本

class Linear(keras.layers.Layer):
    def __init__(self, units=32, input_dim=32):
        super(Linear, self).__init__()
        w_init = tf.random_normal_initializer()
        self.w = tf.Variable(
            initial_value=w_init(shape=(input_dim, units), dtype="float32"),
            trainable=True,
        )
        b_init = tf.zeros_initializer()
        self.b = tf.Variable(
            initial_value=b_init(shape=(units,), dtype="float32"), trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

   def calc(self):
       return self(inputs) ** 2

   def getclass(self):
       return self

superで自身のclassを引数にする必要があります。

x = tf.ones((2, 2))
linear_layer = Linear(4, 2) # (1)
y = linear_layer(x) # (2)
z = linear_layer.calc(x) # (3)
print(y)

(1)では、コンストラクターに代入され、
(2)では、callメソッドに代入されます。
(3)では、自身のcallメソッドを呼び出しているため、selfの中に
代入させます。

   def getclass(self):
       return self

ちなみに、pythonでは、return selfとした場合は、インスタンス変数を引き継いだ状態のclassを返すことができます。

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0