LoginSignup
13
10

More than 1 year has passed since last update.

kerasのModel.fitの処理をカスタマイズする

Last updated at Posted at 2022-02-16

kerasのModel.fitの処理が書いてあるtrain_step()を書き換えてVAEとか蒸留とかGANをやる手順の記事です。

1. 独自の学習ステップの書き方

3つの選択肢があるようです
1. keras.Modelのtrain_stepをoverrideする
2. kerasのカスタムtraining loopを使う
3. tensorflow estimatorを使う

本記事では1をやっていきます

処理の自由度は3が一番高く1が一番低くなりますが、1はkerasのエコシステムがそのまま使えるのが大きなメリットです。コールバックがそのまま使えたり、optimizerの状態を保存できたりと便利なので表現可能なら1で書くのが良さそうです。

1. keras.Modelの処理と同じ動作をするtrain_step

1-1. 準備

importする
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
適当なデータをつくる
def make_data(n=400):
    x = np.linspace(0, 1, n)
    y = x * 2
    return x, y

x, y = make_data(n=400)
プロットする関数
def plot(x, y, y_pred=None):
    plt.figure(figsize=(4, 4))
    plt.scatter(x, y)
    if y_pred is not None:
        plt.scatter(x, y_pred)
    plt.show()

plot(x, y)

image.png

1-2. 元々のkeras.Modelと同じ動作を書く

こちらの記事のコードでやっていきます
https://qiita.com/tonouchi510/items/1203bfe8e6bdb61d9902

Model.fitの動作はtrain_step()に書かれていますので、これを書き換えてカスタマイズしていきますが、まずは元々のkeras.Modelと同じ動作のtrain_stepを見てみましょう

カスタムModelをつくって学習
class MyModel(tf.keras.Model):
    def train_step(self, data):
        x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        self.compiled_metrics.update_state(y, y_pred, sample_weight)
        return {m.name: m.result() for m in self.metrics}

入ってきたdataを x, y に分けて、推定値である y_pred と y の差からlossを算出して勾配を出してself.optimizerでself.trainable_variablesを更新、self.compiled_metricsの状態を更新、更新済みのmetricsを受け取ってreturnするという処理になっています。

勾配算出に使っているtapeについては公式ページに解説があります
https://www.tensorflow.org/tutorials/customization/autodiff?hl=ja

train_step内の動作についても公式ページに解説があります
https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit?hl=ja

あとはtf.keras.Modelの代わりにMyModelでモデルをつくるだけです

inputs = tf.keras.layers.Input(1)
outputs = tf.keras.layers.Dense(1, activation='linear')(inputs)
model = MyModel(inputs, outputs, name='hoge')
opt = tf.keras.optimizers.SGD(learning_rate=0.1)
model.compile(loss='mse', metrics='mse', optimizer=opt)

model.fit(x, y, validation_data=(x, y), epochs=100)

y_pred = model.predict(x)
plot(x, y, y_pred)

image.png

2. 標準のcompiled_lossをそのまま使える場合

Model.compile()で指定したcompiled_lossを使うとepochごとのlossのクリアやlossの表示などを自分でやらなくて済んで楽なので、まずはこれをやってみたいと思います。

ここではcompiled_lossをそのまま使える例として、回帰モデルの蒸留をやってみます。xをシャッフルしてmixupしたものを受け取って、教師モデルで推定したyを使って生徒モデルを学習させて、教師モデルと同じ結果を出力する生徒モデルをつくるイメージでやってみます。

2-1. 教師モデルの作成

まず教師モデルを作ります。

教師モデルの作成
def make_model(name):
    inputs = tf.keras.layers.Input(1)
    outputs = tf.keras.layers.Dense(1, activation='linear')(inputs)
    model = tf.keras.models.Model(inputs, outputs, name=name)
    opt = tf.keras.optimizers.SGD(learning_rate=0.1)
    model.compile(loss='mse', optimizer=opt)
    return model

teacher = make_model('teacher')
student = make_model('student')
teacher.fit(x, y, validation_data=(x, y), epochs=100)

y_pred = teacher.predict(x)
plot(x, y, y_pred)

image.png

本来は大きいけど精度が高い教師モデルを小さい生徒モデルに再現させる為にやるものなので、教師モデルが上記のように小さいならわざわざ蒸留する必要なんかないんですが書き方の例としてこれでやってみます。

2-2. 変更は最小限にしたカスタムモデル

とりあえず変更を最小限にやりたいことをやってみます

カスタムモデルで学習
class Distiller(tf.keras.Model):
    def __init__(self, inputs, outputs, teacher, **kwargs):
        super(Distiller, self).__init__(inputs, outputs, **kwargs)
        self.teacher = teacher

    def train_step(self, data):
        x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
        y_teacher = self.teacher(x, training=False)  # ← 変更箇所

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y_teacher, y_pred)  # ← 変更箇所
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients((grad, var) for (grad, var) in zip(gradients, model.trainable_variables) 
                                       if grad is not None)  # ← 変更箇所

        self.compiled_metrics.update_state(y, y_pred, sample_weight)
        return {m.name: m.result() for m in self.metrics}    

シャッフルしてmixupするのは別途やっている想定で、受け取ったxに対する教師モデルの推定値y_teacherをstudentモデルが学習します。このとき、教師モデルは学習対象にしたくないので、training=Falseとして更新が掛からないようにしています。

training=Falseにするとその部分の勾配がNoneになりますが、そのまま渡すとself.optimizer.apply_gradients()が警告メッセージを出してきちゃうので、Noneでないものだけ渡すようにしています。(参考にしたページ)

ということで生徒モデルを学習してみます

学習してみる
inputs = tf.keras.layers.Input(1)
outputs = tf.keras.layers.Dense(1, activation='linear')(inputs)
model = Distiller(inputs, outputs, teacher)
opt = tf.keras.optimizers.SGD(learning_rate=0.1)
model.compile(loss='mse', metrics='mae', optimizer=opt)#, run_eagerly=True)

model.fit(x, y, validation_data=(x, y), epochs=100)
y_pred = model.predict(x)
plot(x, y, y_pred)

image.png
このモデルのsummaryを見ると普通のモデルと同じように見えます。teacher部分はtrain_stepに書き加えているだけなのでsummaryには出てきません。

モデルのsummary
>>> model.summary()

Model: "distiller_23"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_38 (InputLayer)        [(None, 1)]               0         
_________________________________________________________________
dense_37 (Dense)             (None, 1)                 2         
_________________________________________________________________
teacher (Functional)         (None, 1)                 2         
=================================================================
Total params: 4
Trainable params: 4
Non-trainable params: 0
_________________________________________________________________

ではweightはどうかというと生徒モデルと教師モデルがセットで入っています。上が生徒モデル、下が教師モデルのweightになっています。

>>> model.weights
[<tf.Variable 'dense_3/kernel:0' shape=(1, 1) dtype=float32, numpy=array([[1.9999967]], dtype=float32)>,
 <tf.Variable 'dense_3/bias:0' shape=(1,) dtype=float32, numpy=array([1.6316116e-06], dtype=float32)>,
 <tf.Variable 'dense_2/kernel:0' shape=(1, 1) dtype=float32, numpy=array([[1.9999983]], dtype=float32)>,
 <tf.Variable 'dense_2/bias:0' shape=(1,) dtype=float32, numpy=array([8.3305986e-07], dtype=float32)>]

>>> model.teacher.weights
[<tf.Variable 'dense_2/kernel:0' shape=(1, 1) dtype=float32, numpy=array([[1.9999983]], dtype=float32)>,
 <tf.Variable 'dense_2/bias:0' shape=(1,) dtype=float32, numpy=array([8.3305986e-07], dtype=float32)>]

生徒モデルだけ取り出して使いたいときにちょっとやりにくいので、次の項で生徒モデルを分けてみたいと思います。

2-3. 生徒モデルを分けて書くカスタムモデル

inputs, outputsを渡してDistillerクラスに生徒モデルをつくるのでなく、生徒モデルをkeras.Modelでつくっておいて、Distillerクラスに組み込むようにしてみます。

この場合はtest_step()も書き換える必要があります。3-2の書き方だとDistillerクラスに生徒モデルが格納されていたのでtest_step()でself(x)したときに生徒モデルが呼ばれて正常に動作していたんですが、こちらでは分けているのでtest_step()で生徒モデルが推定してくれるように書き換えるわけです。

カスタムモデルで学習
class Distiller(tf.keras.Model):
    def __init__(self, teacher, student, **kwargs):
        super(Distiller, self).__init__(**kwargs)
        self.teacher = teacher
        self.student = student

    def train_step(self, data):
        x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
        y_teacher = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            y_pred = self.student(x, training=True)
            loss = self.compiled_loss(y_teacher, y_pred)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients((grad, var) for (grad, var) in zip(gradients, model.trainable_variables) 
                                       if grad is not None)

        self.compiled_metrics.update_state(y, y_pred, sample_weight)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
        y_pred = self.student(x, training=False)
        loss = self.compiled_loss(y, y_pred)
        self.compiled_metrics.update_state(y, y_pred, sample_weight)
        return {m.name: m.result() for m in self.metrics}

model = Distiller(teacher, student)
opt = tf.keras.optimizers.SGD(learning_rate=0.1)
model.compile(loss='mse', metrics=('mae'), optimizer=opt, run_eagerly=True)

model.fit(x, y, validation_data=(x, y), epochs=100)
y_pred = model.student.predict(x)
plot(x, y, y_pred)

image.png

これで生徒モデルを分けて扱えるようになりました

3. compiled_lossを追加する場合

元々のModel.compile()が扱えないようなloss functionであっても、Model.compile()をoverrideして追加してやれば諸々を同様にやってくれます。

題材として判別モデルの蒸留をやってみます。回帰モデルと違って教師モデルからsoftmax出力が得られますので、これを生徒モデルが利用できるようにすることで教師モデルより小さい生徒モデルが同等の精度を得られるようにしようというものですが、その為に元々あるloss functionをそのまま使うというわけにはいかなくなるのでちょうど良い題材だと思います。

コードはkeras.ioにあったものをほぼそのまま使っています
https://keras.io/examples/vision/knowledge_distillation/

3-1. データの準備

mnistを使います

# Prepare the train and test dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
1
# Normalize data
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))

plt.imshow(x_train[0])

image.png

3-2. 教師/生徒モデルの作成

こちらも本来は大きくて精度の高い教師モデルを小さい生徒モデルで再現する為のものですが、書き方の確認をするだけなので同じモデルで作っています

def make_model(name):
    inputs = tf.keras.layers.Input(shape=(28, 28, 1))
    network = tf.keras.layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same")(inputs)
    network = tf.keras.layers.LeakyReLU(alpha=0.2)(network)
    network = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same")(network)
    network = tf.keras.layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same")(network)
    network = tf.keras.layers.Flatten()(network)
    outputs = tf.keras.layers.Dense(10)(network)
    model = tf.keras.models.Model(inputs, outputs, name=name)
    return model

teacher = make_model('teacher')
student = make_model('student')
teacher.summary()
student.summary()

teacher.compile(optimizer=tf.keras.optimizers.Adam(),
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
teacher.fit(x_train, y_train, epochs=5)
print(teacher.evaluate(x_test, y_test))

3-3. カスタムモデルをつくって学習

Model.compile()にself.student_loss_fnとself.distillation_loss_fnを追加することでupdate_stateが掛かるようにしています。tensorflowの公式ページではModel.metrics()を書くことでupdate_state()が掛かるようにしているので書き方が違うんですが、どちらが良いのかよく分からないので知っている人がいたら教えて下さい。

class Distiller(tf.keras.Model):
    def __init__(self, teacher, student, **kwargs):
        super(Distiller, self).__init__(**kwargs)
        self.teacher = teacher
        self.student = student

    def compile(self, optimizer, metrics, student_loss_fn, distillation_loss_fn, alpha=0.1, temperature=3, **kwargs):
        super(Distiller, self).compile(optimizer=optimizer, metrics=metrics, **kwargs)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)

        # 教師モデルで推定
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # 生徒モデルで推定
            student_predictions = self.student(x, training=True)

            # lossを算出
            student_loss = self.student_loss_fn(y, student_predictions)
            distillation_loss = self.distillation_loss_fn(
                tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                tf.nn.softmax(student_predictions / self.temperature, axis=1),
            )
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # 勾配を算出
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # optimizerでweightを更新
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # compiled_metricsを更新
        self.compiled_metrics.update_state(y, student_predictions)

        # metricsを算出して返す
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
        y_pred = self.student(x, training=False)
        student_loss = self.student_loss_fn(y, y_pred)
        self.compiled_metrics.update_state(y, y_pred, sample_weight)

        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results


distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=tf.keras.optimizers.Adam(),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=tf.keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
    run_eagerly=True
)

distiller.fit(x_train, y_train, epochs=3)
distiller.evaluate(x_test, y_test)

4. Model.metrics()に書くやり方 + カスタムレイヤーの追加

もうちょっとややこしい作例としてカスタムレイヤーの追加が必要になるVAEをやってみます

カスタムレイヤーの作り方についてはkeras.ioに記事がありました
https://keras.io/guides/making_new_layers_and_models_via_subclassing/

VAEのコードもkeras.ioにありました
https://keras.io/search.html?q=vae

VAEってなんぞという方は以下の記事が分かりやすいと思います
https://qiita.com/kenmatsu4/items/b029d697e9995d93aa24

4-1. データの準備

mnistのデータを使いますがラベルは必要ないので画像のみ、分ける必要がないので学習/テストデータを1つにまとめちゃいます。

(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
mnist_digits = np.concatenate([x_train, x_test], axis=0)
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255
plt.imshow(mnist_digits[0])
plt.show()

image.png

4-2. サンプリングレイヤーをつくる

潜在変数zで表現されるlatent_dim次元の正規分布をサンプリングするレイヤーを作成します

class Sampling(tf.keras.layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

4-3. encoderをつくる

画像からCNNで次元を落としていって、最終的にlatent_dim次元の正規分布まで次元削減するモデルをつくります。先程のSamplingレイヤーはencoderの最後に入ります。

latent_dim = 2

def make_encoder(latent_dim):
    encoder_inputs = tf.keras.Input(shape=(28, 28, 1))
    x = tf.keras.layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
    x = tf.keras.layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(16, activation="relu")(x)
    z_mean = tf.keras.layers.Dense(latent_dim, name="z_mean")(x)
    z_log_var = tf.keras.layers.Dense(latent_dim, name="z_log_var")(x)
    z = Sampling()([z_mean, z_log_var])
    encoder = tf.keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
    return encoder

encoder = make_encoder(latent_dim)

encoder.summary()
tf.keras.utils.plot_model(encoder, show_shapes=True, expand_nested=True)
Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_13 (InputLayer)           [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 14, 14, 32)   320         input_13[0][0]                   
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 7, 7, 64)     18496       conv2d_10[0][0]                  
__________________________________________________________________________________________________
flatten_5 (Flatten)             (None, 3136)         0           conv2d_11[0][0]                  
__________________________________________________________________________________________________
dense_11 (Dense)                (None, 16)           50192       flatten_5[0][0]                  
__________________________________________________________________________________________________
z_mean (Dense)                  (None, 2)            34          dense_11[0][0]                   
__________________________________________________________________________________________________
z_log_var (Dense)               (None, 2)            34          dense_11[0][0]                   
__________________________________________________________________________________________________
sampling_4 (Sampling)           (None, 2)            0           z_mean[0][0]                     
                                                                 z_log_var[0][0]                  
==================================================================================================
Total params: 69,076
Trainable params: 69,076
Non-trainable params: 0
__________________________________________________________________________________________________

image.png

4-4. decoderをつくる

encoderとは逆にlatent_dim次元の正規分布から元画像と同じ次元まで次元を増やしていくdecoderモデルをつくります

def make_decoder(latent_dim):
    latent_inputs = tf.keras.Input(shape=(latent_dim,))
    x = tf.keras.layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
    x = tf.keras.layers.Reshape((7, 7, 64))(x)
    x = tf.keras.layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
    x = tf.keras.layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
    decoder_outputs = tf.keras.layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
    decoder = tf.keras.Model(latent_inputs, decoder_outputs, name="decoder")
    return decoder

decoder = make_decoder(latent_dim)

decoder.summary()
tf.keras.utils.plot_model(decoder, show_shapes=True, expand_nested=True)
Model: "decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_14 (InputLayer)        [(None, 2)]               0         
_________________________________________________________________
dense_12 (Dense)             (None, 3136)              9408      
_________________________________________________________________
reshape_6 (Reshape)          (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_transpose_18 (Conv2DT (None, 14, 14, 64)        36928     
_________________________________________________________________
conv2d_transpose_19 (Conv2DT (None, 28, 28, 32)        18464     
_________________________________________________________________
conv2d_transpose_20 (Conv2DT (None, 28, 28, 1)         289       
=================================================================
Total params: 65,089
Trainable params: 65,089
Non-trainable params: 0
_________________________________________________________________

image.png

4-5. VAEモデルをつくる

先程の蒸留とは違いlossをModel.compile()に渡すのでなく、Model.metrics()のreturnで返すやり方になっています。compiled_metricsであればepochごとの状態のクリアを自動でやってくれますがこの作例ではcompileしていません。そういう書き方をする場合はModel.metricsのreturnで追加したlossを返すようにすればcompileしたときと同様、epochごとの状態クリアをやってくれるようになります。

tensorflow公式に紹介されているのがこのやり方なので、どちらかというとこちらの方がスタンダードな書き方なのかな?

class VAE(tf.keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    tf.keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
                )
            )
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

encoder = make_encoder(latent_dim)
decoder = make_decoder(latent_dim)

vae = VAE(encoder, decoder)
vae.compile(optimizer=tf.keras.optimizers.Adam())
vae.fit(mnist_digits, epochs=30, batch_size=128)

4-6. 潜在変数を可視化する

潜在変数を2次元にしてあるので、2次元にプロットすることが出来ます。そこで、潜在変数に対してどんな画像がdecodeされるかを可視化してみます。

def plot_latent_space(vae, n=30, figsize=15):
    # display a n*n 2D manifold of digits
    digit_size = 28
    scale = 1.0
    figure = np.zeros((digit_size * n, digit_size * n))
    # linearly spaced coordinates corresponding to the 2D plot
    # of digit classes in the latent space
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]])
            x_decoded = vae.decoder.predict(z_sample)
            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[
                i * digit_size : (i + 1) * digit_size,
                j * digit_size : (j + 1) * digit_size,
            ] = digit

    plt.figure(figsize=(figsize, figsize))
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()

plot_latent_space(vae)

それっぽいですね
image.png

4-7. encoder

学習に使ったデータをラベルごとにencodeするとどの辺に配置されるのかプロットします。

def plot_label_clusters(vae, data, labels):
    # display a 2D plot of the digit classes in the latent space
    z_mean, _, _ = vae.encoder.predict(data)
    plt.figure(figsize=(12, 10))
    plt.scatter(z_mean[:, 0], z_mean[:, 1], c=labels)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.show()


(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train, -1).astype("float32") / 255

plot_label_clusters(vae, x_train, y_train)

image.png

5. 複数のモデルをそれぞれ学習する

最後に複数のモデルを交互に学習する例としてGANをやってみます

keras.ioにConditional GANの例があったのでこれをやっていきます
https://keras.io/examples/generative/conditional_gan/

GANがなんだか分からない人はこの辺を読みましょう
https://qiita.com/mm_918/items/c6bf085e7618b8af3d98

5-1. データの準備

定数の指定
batch_size = 64
num_channels = 1
num_classes = 10
image_size = 28
latent_dim = 128

generator_in_channels = latent_dim + num_classes
discriminator_in_channels = num_channels + num_classes
print(generator_in_channels, discriminator_in_channels)

28 x 28 pixelのグレースケール画像と0~9の10クラスのどれであるのかをone-hotで表現したラベルをtf.data.Datasetにしています

データの準備
# We'll use all the available examples from both the training and test
# sets.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_labels = np.concatenate([y_train, y_test])

# Scale the pixel values to [0, 1] range, add a channel dimension to
# the images, and one-hot encode the labels.
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
all_labels = tf.keras.utils.to_categorical(all_labels, 10)

# Create tf.data.Dataset.
dataset = tf.data.Dataset.from_tensor_slices((all_digits, all_labels))
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)

print(f"Shape of training images: {all_digits.shape}")
print(f"Shape of training labels: {all_labels.shape}")
確認
k = np.random.randint(0, all_digits.shape[0], 1)[0]
print(k)
plt.imshow(all_digits[k])
plt.title(all_labels[k])
plt.show()

image.png

5-2. discriminatorの作成

mnistで普通に判別をやる場合は画像と0~9の正解ラベルの2つの情報を使って学習をしますが、discriminatorは画像と0~9の正解ラベルからその画像がgeneratorが乱数からつくったフェイクなのか本物なのかを判別するように学習します。

28x28 pixelの画像が来るので(28, 28, 1)が入力されるのかと思ったら、discriminator_in_channelが11なので(28, 28, 11)で入力されるようになっています。残りの10次元はなんなのかというと、one-hotのラベルがここに入っています。ラベルを担当する28x28の行列が10枚あって、正解ラベルに該当する行列は値がすべて1、正解ラベル以外の行列9枚は値がすべて0になっているものを受け取ります。

無駄に行列が大きいようにも思えますが、畳み込み層がラベル情報を上手く扱えるようにするには畳み込みの範囲内に常に情報がある状態にする為にこうなっているんじゃないかと思います。

def make_discriminator():
    discriminator = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer((28, 28, discriminator_in_channels)),
            tf.keras.layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
            tf.keras.layers.LeakyReLU(alpha=0.2),
            tf.keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
            tf.keras.layers.LeakyReLU(alpha=0.2),
            tf.keras.layers.GlobalMaxPooling2D(),
            tf.keras.layers.Dense(1),
        ],
        name="discriminator",
    )
    return discriminator

discriminator = make_discriminator()
tf.keras.utils.plot_model(discriminator, show_shapes=True, expand_nested=True)

5-3. generatorをつくる

latent_dim次元の乱数と0~9のone-hotラベルからフェイク画像をつくるモデルをつくります。画像だけでなくラベルをつけて学習することで、ラベルを指定して画像生成が出来るようになります。

この作例ではlatent_dimが128次元なので、0~9のone-hotラベルの10次元と合わせて138次元の入力から作成します。

def make_generator():
    generator = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer((generator_in_channels,)),
            # We want to generate 128 + num_classes coefficients to reshape into a
            # 7x7x(128 + num_classes) map.
            tf.keras.layers.Dense(7 * 7 * generator_in_channels),
            tf.keras.layers.LeakyReLU(alpha=0.2),
            tf.keras.layers.Reshape((7, 7, generator_in_channels)),
            tf.keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
            tf.keras.layers.LeakyReLU(alpha=0.2),
            tf.keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
            tf.keras.layers.LeakyReLU(alpha=0.2),
            tf.keras.layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
        ],
        name="generator",
    )
    return generator

generator = make_generator()
tf.keras.utils.plot_model(generator, show_shapes=True, expand_nested=True)

5-4. GANモデルをつくる

1つのbatchに対してdiscriminatorの学習、generatorの学習を続けて行います。それぞれのモデルの構造が違う為にweightの次元も違いますし、learning_rateなどの最適値も違ってくるでしょうからoptimizerも分けているという事だと思います。

class ConditionalGAN(tf.keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super(ConditionalGAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.gen_loss_tracker = tf.keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = tf.keras.metrics.Mean(name="discriminator_loss")

    @property
    def metrics(self):
        return [self.gen_loss_tracker, self.disc_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super(ConditionalGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, data):
        # Unpack the data.
        real_images, one_hot_labels = data

        # Add dummy dimensions to the labels so that they can be concatenated with
        # the images. This is for the discriminator.
        image_one_hot_labels = one_hot_labels[:, :, None, None]  # (64, 10, 1, 1)
        image_one_hot_labels = tf.repeat(
            image_one_hot_labels, repeats=[image_size * image_size]
        )
        image_one_hot_labels = tf.reshape(
            image_one_hot_labels, (-1, num_classes, image_size, image_size)
        )  # (64, 10, 28, 28)
        image_one_hot_labels = tf.transpose(image_one_hot_labels, (0, 2, 3, 1))  # (64, 28, 28, 10)

        # Sample random points in the latent space and concatenate the labels.
        # This is for the generator.
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        random_vector_labels = tf.concat(
            [random_latent_vectors, one_hot_labels], axis=1
        )  # (64, 138)

        # Decode the noise (guided by labels) to fake images.
        generated_images = self.generator(random_vector_labels)  # (64, 28, 28, 1)

        # Combine them with real images. Note that we are concatenating the labels
        # with these images here.
        fake_image_and_labels = tf.concat([generated_images, image_one_hot_labels], -1)  # (64, 28, 28, 11)
        real_image_and_labels = tf.concat([real_images, image_one_hot_labels], -1)  # (64, 28, 28, 11)
        combined_images = tf.concat(
            [fake_image_and_labels, real_image_and_labels], axis=0
        )  # (128, 28, 28, 11)

        # Assemble labels discriminating real from fake images.
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )  # (128, 1)

        # Train the discriminator.
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)  # (128, 1)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample random points in the latent space.
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))  # (64, 128)
        random_vector_labels = tf.concat(
            [random_latent_vectors, one_hot_labels], axis=1
        )  # (64, 138)

        # Assemble labels that say "all real images".
        misleading_labels = tf.zeros((batch_size, 1))  # (64, 1)

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_vector_labels)  # (64, 28, 28, 1)
            fake_image_and_labels = tf.concat([fake_images, image_one_hot_labels], -1)  # (64, 28, 28, 11)
            predictions = self.discriminator(fake_image_and_labels)  # (64, 1)
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Monitor loss.
        self.gen_loss_tracker.update_state(g_loss)
        self.disc_loss_tracker.update_state(d_loss)
        return {
            "g_loss": self.gen_loss_tracker.result(),
            "d_loss": self.disc_loss_tracker.result(),
        }

5-5. 学習

その時点でgeneratorがどんな画像をつくるのか確認する関数をつくっておきます

generatorがつくる画像を確認
def check_generator(generator, n=3):
    _, one_hot_labels = next(iter(dataset))

    image_one_hot_labels = one_hot_labels[:, :, None, None]
    image_one_hot_labels = tf.repeat(
        image_one_hot_labels, repeats=[image_size * image_size]
    )
    image_one_hot_labels = tf.reshape(
        image_one_hot_labels, (-1, num_classes, image_size, image_size)
    )  # (64, 10, 28, 28)
    image_one_hot_labels = tf.transpose(image_one_hot_labels, (0, 2, 3, 1))  # (64, 28, 28, 10)

    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))  # (64, 128)

    random_vector_labels = tf.concat(
        [random_latent_vectors, one_hot_labels], axis=1
    )  # (64, 138)

    generated_images = generator(random_vector_labels)  # (64, 28, 28, 1)

    fake_image_and_labels = tf.concat([generated_images, image_one_hot_labels], -1)

    for k in np.random.randint(0, batch_size, n):

        label = np.where(image_one_hot_labels[k][0, 0])[0][0]
        print('k: {}, label: {}'.format(k ,label))
        fake_image_and_label = fake_image_and_labels[k]
        print('fake_image_and_label: {}'.format(fake_image_and_label.shape))

        plt.imshow(fake_image_and_label[:, :, 0])
        plt.show()

        fig, ax = plt.subplots(2, 5, figsize=(10, 4))
        for pos, channel in enumerate(range(1, 11)):
            im = fake_image_and_label[:, :, channel]
            i, j = pos // 5, pos % 5
            ax[i, j].imshow(im, vmin=0, vmax=1)
            ax[i, j].set_xticks([])
            ax[i, j].set_yticks([])
        plt.show()

check_generator(generator)
学習する
discriminator = make_discriminator()
generator = make_generator()

cond_gan = ConditionalGAN(
    discriminator=discriminator, generator=generator, latent_dim=latent_dim
)
cond_gan.compile(
    d_optimizer=tf.keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=tf.keras.optimizers.Adam(learning_rate=0.0003),
    loss_fn=tf.keras.losses.BinaryCrossentropy(from_logits=True),
)

total_epochs = 0
print('#####################################')
print('### epochs: {}\n'.format(total_epochs))

check_generator(generator, n=3)

for epochs in [10, 10]:
    cond_gan.fit(dataset, epochs=epochs)

    print('#####################################')
    print('### epochs: {}\n'.format(total_epochs))
    check_generator(generator, n=3)

補機のgeneratorの出力です。入力もweightの初期値も乱数なので、当然ながら出てくるのは乱数です
image.png
10 epoch学習した後のgeneratorです。ちょっとそれっぽくなってきました
image.png
40 epoch学習してかなりそれっぽくなっています
image.png

6. デバッグするときはeager modeにしよう

Model.compile()すると高速に学習できるようにcompileしてくれるわけですが、その結果エラー発生箇所が分かりにくくなりますが、デバッグする時の為にcompileしないで学習することも出来るようになっています。

run_eagerly=Trueにする
model.compile(loss='mse', metrics='mse', optimizer='adam', run_eagerly=True)

7. まとめ

いきなりややこしい学習を書くと動かない原因が分からずつらいので、シンプルなところから1つずつ足していく書き方が良いんじゃないかと思います。

レッツトライ

参考記事

https://qiita.com/tonouchi510/items/1203bfe8e6bdb61d9902
https://zenn.dev/koshian2/articles/5cdc96f2feeda8
https://twitter.com/fchollet/status/1348333222791319554
https://keras.io/guides/making_new_layers_and_models_via_subclassing/
https://qiita.com/tatsuya11bbs/items/7d7a2c920730ae0c592a
https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch
https://qiita.com/shinmura0/items/e51565960648dccf8486

13
10
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
10