LoginSignup
4
4

More than 1 year has passed since last update.

VAEの理論と実装

Posted at

はじめに

vaeをTensorflow2で実装してみました.

データセット、ライブラリはそれぞれ以下を使用しています.
ライブラリ: Tensorflow2
データセット: FashionMnist

本記事で紹介しているコードは以下で実行できます.
https://colab.research.google.com/drive/1B-rtbHOatjavv8V8lfPYs9ymZ0o1sUL9?usp=sharing

VAEとは

端的に言うと、オートエンコーダ型の生成モデルです.

生成モデルとは

データが何かしら確率的な分布から生成されていると仮定し、
データが生成される元の確率分布を学習するモデルのことをいいます.

オートエンコーダとは

データをより少ない次元に変換する次元圧縮器で、以下のような構造をとります.
入力値をエンコーダーで潜在変数に圧縮し、潜在変数からデコーダーで元の次元に復元します.
ae.png

学習のゴール

VAEでは、データが正規分布から生成されていると仮定します.
正規分布は期待値と分散によって決まるのでした.
データが生成される元の変数を潜在変数と呼び、VAEでは潜在変数の期待値と分散を学習することがゴールになります.

VAEの構造

どうやって学習するのかというと、以下の形をとります.

  1. エンコーダは期待値と分散の2つを出力.
  2. 期待値と分散から潜在変数をサンプリング.
  3. デコーダで潜在変数から元の次元を復元する

デコーダは通常のオートエンコーダと同じですが、
エンコーダの出力と、そこからサンプリング処理があることが、オートエンコーダとの大きな違いになります.

vae.png

エンコーダの実装

エンコーダは元のデータから期待値と分散の2つを出力します.
分散は必ず0以上になる必要がありますが、全結合層の出力は必ず0以上になるわけではないので、エンコーダーから出力される分散は対数を取ったものを出力します.

class Encoder(layers.Layer):

    def __init__(self, latent_dim=10):
        super().__init__()
        self.c1 = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")
        self.c2 = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")
        self.f1 = layers.Flatten()
        self.d1 = layers.Dense(16, activation="relu")
        self.d2 = layers.Dense(latent_dim, name="u")
        self.d3 = layers.Dense(latent_dim, name="var")

    def call(self, x):
        x = self.c1(x)
        x = self.c2(x)
        x = self.f1(x)
        x = self.d1(x)
        u = self.d2(x)
        log_var = self.d3(x)
        return u, log_var

サンプリング層の実装

エンコーダから出力した期待値と分散から潜在変数をサンプリングしています.
誤差逆伝播で学習するには微分可能でないといけないので、以下のような式で擬似的なサンプリングをします
この方法をreparameterization trickといいます.

潜在変数 = 期待値(u) + 標準偏差(σ) * 乱数

class Sampling(layers.Layer):

    def call(self, u, log_var):
        u_shape = tf.shape(u)
        epsilon = tf.keras.backend.random_normal(shape=(u_shape[0], u_shape[1]))
        return u + tf.exp(0.5 * log_var) * epsilon

デコーダの実装

デコーダでは、潜在変数から元のデータを復元します.
ここはオートエンコーダと同じです.


class Decoder(layers.Layer):

    def __init__(self):
        super().__init__()
        self.d1 = layers.Dense(7 * 7 * 64, activation="relu")
        self.r1 = layers.Reshape((7, 7, 64))
        self.c1 = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")
        self.c2 = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")
        self.c3 = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")

    def call(self, x):
        x = self.d1(x)
        x = self.r1(x)
        x = self.c1(x)
        x = self.c2(x)
        y = self.c3(x)
        return y

損失関数の実装

VAEの損失関数は2つの項からなります.

VAEの損失 = 再構成ロス + KLダイバージェンス

再構成ロス

  • 生成されたデータと元データの差を計算
  • 今回は交差エントロピーを採用
  • 大体は2乗誤差か交差エントロピーのどちらかが使われるが、どちらにするかはケースバイケース.

KLダイバージェンス

  • 2つの確率分布間の距離
  • VAEでの定義は以下 
    • Q: 仮定している正規分布(平均0 分散1)
    • P: 実際のエンコーダの出力
  • エンコーダの出力を平均0,分散1に正規分布に近づけることが目標
  • 定義をもとに計算すると最終的に以下の式になる

$$
{D_{KL}(P | Q)
= - \frac{1}{2} \sum_{j=1}^J (1 + \log((\sigma_j)^2) - (\mu_j)^2 - (\sigma_j)^2)
}
$$

class VAELossFunction:

    def __call__(self, u, log_var, z, y, y_true) -> (float, float, float):
        # 再構成ロス
        rc_loss = tf.reduce_mean(
            tf.reduce_sum(
                tf.keras.losses.binary_crossentropy(y_true, y),
                axis=(1, 2)
            )
        )
        # KLダイバージェンス
        kl_loss = tf.reduce_mean(
            tf.reduce_sum(
                -0.5 * (1 + log_var - tf.square(u) - tf.exp(log_var)),
                axis=1
            ))
        return rc_loss, kl_loss, rc_loss + kl_loss

モデル全体の実装

モデル全体の計算の流れを記述しています.

  1. 入力値をエンコーダーに流して期待値と分散を取得
  2. 期待値と分散をサンプリング層に渡して潜在変数を取得
  3. 潜在変数からデコーダで元の次元に復元

💡Point

  • 訓練用のメソッド(train_step)の出力はVAEResultとし定義
  • 学習時の計算はtf.functionでデコレートしたtrain_step_functionで定義
    • eager modeで動的に計算グラフを作成するため
    • tf.functionはtf.Tensorに変換できる戻り値である必要があるので、train_stepとは別メソッドで定義
  • 予測用にバッチサイズなしの画像を受け取って潜在変数と復元データを返すメソッドを作成(predict)
@dataclass
class VAEResult:
    rc_loss: float
    kl_loss: float
    total_loss: float


class VAE(Model):

    def __init__(self, optimizer):
        super().__init__()
        self.encoder = Encoder(latent_dim=2)
        self.decoder = Decoder()
        self.sampling = Sampling()
        self.loss_function = VAELossFunction()
        self.optimizer = optimizer

    def call(self, x: tf.Tensor):
        u, var = self.encoder(x)
        z = self.sampling(u, var)
        y = self.decoder(z)
        return u, var, z, y

    @tf.function
    def train_step_function(self, y_true: tf.Tensor):
        with tf.GradientTape() as tape:
            u, var, z, y = self.call(y_true)
            rc_loss, kl_loss, total_loss = self.loss_function(u, var, z, y, y_true)
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return rc_loss, kl_loss, total_loss

    def train_step(self, y_true: tf.Tensor) -> VAEResult:
        rc_loss, kl_loss, total_loss = self.train_step_function(y_true)
        return VAEResult(
            rc_loss=rc_loss,
            kl_loss=kl_loss,
            total_loss=total_loss
          )

    def test_step(self, y_test: tf.Tensor) -> (tf.Tensor, np.array):
        u, var, z, y = self.call(y_test)
        rc_loss, kl_loss, total_loss = self.loss_function(u, var, z, y, y_test)
        return VAEResult(
            rc_loss=rc_loss,
            kl_loss=kl_loss,
            total_loss=total_loss
          )

    def predict(self, image):
        tensor = tf.convert_to_tensor(image, dtype=tf.float32)
        tensor = tf.expand_dims(tensor, axis=0)
        u, var, z, y = self.call(tensor)
        z = tf.squeeze(z, [0])
        y = tf.squeeze(y, [0])
        return z, y

メトリクスの実装

学習時のメトリクスを定義してます

  • update_on_trainでモデルで計算された各損失の平均を記録
  • displayでメトリクスの情報を表示
  • resetでメトリクスの状態をクリア
import tensorflow as tf
from tensorflow.keras import metrics


class VAETrainMetrics:

    def __init__(self):
        self.rc_loss = metrics.Mean(name='train_rc_loss')
        self.kl_loss = metrics.Mean(name='train_kl_loss')
        self.total_loss = metrics.Mean(name='train_total_loss')

    def update_on_train(self, result: VAEResult):
        self.rc_loss.update_state(result.rc_loss)
        self.kl_loss.update_state(result.kl_loss)
        self.total_loss.update_state(result.total_loss)

    def update_on_test(self, result: VAEResult):
        pass

    def reset(self):
        self.rc_loss.reset_states()
        self.kl_loss.reset_states()
        self.total_loss.reset_states()

    def display(self, epoch: int):
        template = 'Epoch {}, RC Loss: {:.2g}, KL Loss: {:.2g}, Total Loss: {:.2g}'
        print(
            template.format(
                epoch,
                self.rc_loss.result(),
                self.kl_loss.result(),
                self.total_loss.result()
            )
        )

学習器の実装

いつも通り、学習のイテレーションを記載してます.
以下のステップをエポック数だけ繰り返します.

  1. バッチ化したデータをループで回す
  2. モデル渡して損失を取得
  3. 損失をメトリクスに記録
class VAETrainer:

    def __init__(self,  model,  metrics):
        self.model = model
        self.metrics = metrics

    def fit(self, dataset,  epochs: int = 3):
        """ Train mdoel by epochs
        """
        for e in range(1, epochs + 1):

            # train model
            for images in dataset.train_loop():
                result = self.model.train_step(images)

            # test model
            for images in dataset.test_loop():
                result = self.model.test_step(images)
                self.metrics.update_on_test(result)

            # show metrics
            self.metrics.display(e)
            self.metrics.reset()

データセットの用意

お馴染みのFashionMnistを使用しました.
https://github.com/zalandoresearch/fashion-mnist/blob/master/README.ja.md
fashion_mnist_5x5.png

学習の実行

とりあえず、10エポックほど学習してみます.

from tensorflow.keras import optimizers

# モデル初期化
vae = VAE(optimizer=optimizers.Adam())

# メトリクス初期化
vae_metrics = VAETrainMetrics()

# 学習用クラス初期化
vae_trainer = VAETrainer(vae, vae_metrics)

# データセットクラス初期化
fashion_mnist_dataset = FashionMnistDataset()

# 学習実行
vae_trainer.fit(fashion_mnist_dataset, epochs=10)

実行結果

Epoch 1, RC Loss: 2.9e+02, KL Loss: 7.5, Total Loss: 3e+02
Epoch 2, RC Loss: 2.7e+02, KL Loss: 7.5, Total Loss: 2.8e+02
Epoch 3, RC Loss: 2.7e+02, KL Loss: 7.4, Total Loss: 2.7e+02
Epoch 4, RC Loss: 2.6e+02, KL Loss: 7.3, Total Loss: 2.7e+02
Epoch 5, RC Loss: 2.6e+02, KL Loss: 7.1, Total Loss: 2.7e+02
Epoch 6, RC Loss: 2.6e+02, KL Loss: 7, Total Loss: 2.7e+02
Epoch 7, RC Loss: 2.6e+02, KL Loss: 6.8, Total Loss: 2.7e+02
Epoch 8, RC Loss: 2.6e+02, KL Loss: 6.7, Total Loss: 2.7e+02
Epoch 9, RC Loss: 2.6e+02, KL Loss: 6.6, Total Loss: 2.7e+02
Epoch 10, RC Loss: 2.6e+02, KL Loss: 6.6, Total Loss: 2.7e+02

学習結果

適当に組むましたが、割とちゃんと学習されました.
多少ぼやけてるのは、VAEだと致し方ない部分です.
より鮮明に学習したい場合は、GANなどを使う必要がありそうです.

元画像
vae_org.png

復元した画像
vae_result.png

おわりに

tensorflow1系で実装したことは以前ありますが、2系ではなかったので、実装してみました.
非常にスッキリかけて良いですね.
次はGANや自然言絵処理への応用など、より発展した実装をしたくなりました.

4
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
4