5
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

離散空間のReparameterizationTrick(GumbelMaxTrick/StraightGradients/VQ-VAE)

Posted at

離散空間の勾配を計算したい状況があったのでまとめてみました。
タイトルの通り3種類の方法をとり上げますが、あまり情報がなく他にもあるかもしれません…。

1. ReparameterizationTrick(VAE)

まずは連続空間の ReparameterizationTrick を見ていきます。
これは、確率分布を微分可能な関数に置き換えるテクニックで、VAE(Variational Autoencoder)でよく登場するテクニックです。

VAEを簡単に言うと、教師なし学習における特徴抽出の一種で、特徴を標準正規分布上で表現できるように特徴抽出します。
特徴は正規分布に従うので連続空間となります。

これを愚直に表現すると以下の問題が発生します。

aa-ページ2.drawio.png

これを解決する手法が ReparameterizationTrick で、正規分布上でも勾配を流すことができるテクニックです。

aa-ページ3.drawio.png

MNISTによるサンプルコード

参考:TensorFlow > 学ぶ > TensorFlow Core > チュートリアル > 畳み込み変分オートエンコーダ

MNISTで実際に実装してみます。
また以降はこのコードをベースにモデルのみを変更して同じコードを使いまわしていきます。

VAEを正確に実装することが目的ではないので以下の違いがある点は注意してください。

  • Conv2D層はありません。Dense層が1層のみで作成しています。
  • 正則化項に該当するKL損失は省略しています。(なのでこれは標準正規分布には従いません)
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.datasets import mnist

kl = keras.layers

# データの読み込みと前処理
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images = train_images.reshape((60000, 28, 28, 1)).astype("float32") / 255
test_images = test_images.reshape((10000, 28, 28, 1)).astype("float32") / 255


class VAE(keras.Model):
    def __init__(self):
        super().__init__()
        self.z_size = 10

        # --- encoder
        self.enc_layers = [
            kl.Flatten(),
            kl.Dense(128, activation="relu"),
        ]
        self.mean_layer = kl.Dense(self.z_size)
        self.log_stddev_layer = kl.Dense(self.z_size)

        # --- decoder
        self.dec_layers = [
            kl.Dense(128, activation="relu"),
            kl.Dense(28 * 28 * 1),
            kl.Reshape((28, 28, 1)),
        ]

    def call(self, x):
        mean, stddev = self.encode(x)

        # tf.random.normalはshapeがNoneだとエラーになる
        if mean.shape[0] is None:
            z = mean + stddev * 0
        else:
            # --- reparameterization trick
            e = tf.random.normal(shape=mean.shape)
            z = mean + stddev * e

        return self.decode(z)

    def sample(self, x):
        # 乱数を使わない場合は平均をそのまま使う
        mean, stddev = self.encode(x)
        return self.decode(mean)

    def encode(self, x):
        for h in self.enc_layers:
            x = h(x)
        mean = self.mean_layer(x)
        log_stddev = self.log_stddev_layer(x)
        stddev = tf.math.exp(log_stddev)
        return mean, stddev

    def decode(self, x):
        for h in self.dec_layers:
            x = h(x)
        return x


model = VAE()
model.compile(optimizer="adam", loss="mse")

# モデルの訓練
model.fit(train_images, train_images, epochs=10, batch_size=64)
test_loss = model.evaluate(test_images, test_images)
print("Test accuracy:", test_loss)  # Test accuracy: 0.020644141361117363

# 表示
pred_images = model.sample(test_images[:8])
fig = plt.figure(figsize=(4, 4))
for i in range(8):
    plt.subplot(4, 4, i + 1)
    plt.imshow(test_images[i, :, :, 0], cmap="gray")
    plt.axis("off")
for i in range(8):
    plt.subplot(4, 4, 8 + i + 1)
    plt.imshow(pred_images[i, :, :, 0], cmap="gray")
    plt.axis("off")
plt.show()

結果は以下です。

Figure_1.png

上2段が入力で下2段が出力結果です。
学習は出来ていそうですね。

2. 学習できないCategoricalVAE

VAEでは特徴を正規分布(連続空間)と仮定しましたが、カテゴリカル分布(離散空間)と仮定して作成します。

aa-ページ4.drawio.png

カテゴリカル分布なので例えば特徴数を10とすれば0~9の値をとります。
どの値を取るかはsoftmaxで確率的に表現し、出力側では確率で決まった値をonehot化して渡します。

コードは以下です。

class CategoricalVAE(keras.Model):
    def __init__(self):
        super().__init__()
        self.z_size = 10

        # --- encoder
        self.enc_layers = [
            kl.Flatten(),
            kl.Dense(128, activation="relu"),
        ]
        self.logits_layer = kl.Dense(self.z_size)

        # --- decoder
        self.dec_layers = [
            kl.Dense(128, activation="relu"),
            kl.Dense(28 * 28 * 1),
            kl.Reshape((28, 28, 1)),
        ]

    def call(self, x):
        logits = self.encode(x)
        sample = tf.random.categorical(logits, 1)
        z = tf.one_hot(tf.squeeze(sample, 1), self.z_size)
        return self.decode(z)

    def sample(self, x):
        return self.call(x)

    def encode(self, x):
        for h in self.enc_layers:
            x = h(x)
        return self.logits_layer(x)

    def decode(self, x):
        for h in self.dec_layers:
            x = h(x)
        return x

# --- modelを変更するだけで実行できます。
#model = VAE()
model = CategoricalVAE()

学習は出来ませんが、エラーなく実行することは出来ました。
学習結果は以下です。

Figure_2.png

勾配が流れないので学習できていませんね。
ちなみに以下警告も出力され、いくつかの変数で勾配が流れていないことを指摘されます。

WARNING:tensorflow:Gradients do not exist for variables ['categorical_vae/dense/kernel:0', 'categorical_vae/dense/bias:0', 'categorical_vae/dense_1/kernel:0', 'categorical_vae/dense_1/bias:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?

3. Gumbel-Max Trick

離散的なReparameterizationTrickを検索したら最初に見つけた手法です。
ざっくりいうとGumbel-Softmaxを使うことでReparameterizationTrickを行う手法です。

参考
Gumbel-Max Trick(ガンベル最大トリック)を理解する | 楽しみながら理解するAI・機械学習入門
Categorical Reparameterization with Gumbel-Softmax | ご注文は機械学習ですか?

Gumbel-Softmax分布

SoftmaxとGumbel-Softmaxは以下です。

\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{K} e^{x_j}}
\text{gumbel_softmax}(x, \tau)_i = \frac{e^{(x_i + g_i)/\tau}}{\sum_{j=1}^{K} e^{(x_j + g_j)/\tau}}

$g$ はGumbel関数からのサンプルを表し、$\tau$ は乱数を制御する温度パラメータです。
このGumbel分布ですが、各要素において以下の$z$が最大値をとる確率はsoftmaxの確率と一致します。

$$
z_i = x_i + g_i
$$

$g$ はGumbel分布に従う乱数で、一様乱数をGumbel分布の逆関数に入れることで取得できます。

\begin{align}
u \sim Uniform(0,1) \\
g = -ln(-ln(u))
\end{align}

実際に見てみます。

import matplotlib.pyplot as plt
import numpy as np


def softmax(x):
    exp_x = np.exp(x - np.max(x))
    return exp_x / np.sum(exp_x, axis=0)


def gumbel_inverse(x):
    return -np.log(-np.log(x))


x = np.array([1.2, 2.3, 1.4, 3.3, 0.1])
num_samples = 100000

# --- random softmax
y_softmax = np.random.choice(len(x), size=num_samples, p=softmax(x))

# --- random gumbel
y_gumbel = []
for _ in range(num_samples):
    rnd = np.random.uniform(size=(len(x)))
    z = x + gumbel_inverse(rnd)
    y_gumbel.append(np.argmax(z))

# --- plot
plt.hist([y_softmax, y_gumbel], label=["softmax", "gumbel"])
plt.legend()
plt.show()

Figure_3.png

見事に一致していますね。
最後にサンプリングと勾配で使う確率ベクトルの式を書いておきます。

・サンプリング
$$
z = \text{onehot}(\underset{i}{\text{argmax}}(ln(x_i) + g_i))
$$

・勾配で使う確率ベクトル
$$
z = \text{softmax}(\frac{ln(x_i) + g_i}{\tau})
$$

コード

実際に学習してみます。

class GumbelVAE(keras.Model):
    def __init__(self):
        super().__init__()
        self.z_size = 10
        self.temperature = 1

        # --- encoder
        self.enc_layers = [
            kl.Flatten(),
            kl.Dense(128, activation="relu"),
        ]
        self.logits_layer = kl.Dense(self.z_size)

        # --- decoder
        self.dec_layers = [
            kl.Dense(128, activation="relu"),
            kl.Dense(28 * 28 * 1),
            kl.Reshape((28, 28, 1)),
        ]

    def gumbel_inverse(self, x):
        return -tf.math.log(-tf.math.log(x))

    def call(self, x):
        logits = self.encode(x)

        # --- Gumbel-Max trick
        rnd = tf.random.uniform(tf.shape(logits), minval=1e-10, maxval=1.0)
        z = tf.nn.softmax((logits + self.gumbel_inverse(rnd)) / self.temperature)

        return self.decode(z)

    def sample(self, x):
        logits = self.encode(x)

        rnd = tf.random.uniform(tf.shape(logits), minval=1e-10, maxval=1.0)
        logits = logits + self.gumbel_inverse(rnd)

        # 最大値とsoftmaxの確率が同じになる
        z = tf.argmax(logits, axis=-1)
        z = tf.one_hot(z, self.z_size)

        return self.decode(z)

    def encode(self, x):
        for h in self.enc_layers:
            x = h(x)
        return self.logits_layer(x)

    def decode(self, x):
        for h in self.dec_layers:
            x = h(x)
        return x

# model = VAE()
# model = CategoricalVAE()
model = GumbelVAE()

Figure_4.png

ちゃんと学習できていますね。

3. Straight-Through Gradients with Automatic Differentiation

名前が分からなかったので論文からそのまま、タイトルは長いので削った形です。
DreamerV2の論文に記載がある手法で、アルゴリズムは以下。

Figure_5.png

サンプリングには勾配を流さず直接確率を計算する部分だけ流すというかなり直接的な方法ですね。
今回やりたい事をやるには一番簡単かも…。

参考
Mastering Atari with Discrete World Models(論文)
Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation(論文の参照元の論文)

コードは以下です。

class StraightGradVAE(keras.Model):
    def __init__(self):
        super().__init__()
        self.z_size = 10

        # --- encoder
        self.enc_layers = [
            kl.Flatten(),
            kl.Dense(128, activation="relu"),
        ]
        self.logits_layer = kl.Dense(self.z_size)

        # --- decoder
        self.dec_layers = [
            kl.Dense(128, activation="relu"),
            kl.Dense(28 * 28 * 1),
            kl.Reshape((28, 28, 1)),
        ]

    def call(self, x):
        logits = self.encode(x)

        # --- Straight-Through Gradients with Automatic Differentiation
        sample = tf.random.categorical(logits, 1)
        sample = tf.one_hot(tf.squeeze(sample, 1), self.z_size)
        probs = tf.nn.softmax(logits)
        z = sample + probs - tf.stop_gradient(probs)

        return self.decode(z)

    def sample(self, x):
        logits = self.encode(x)
        sample = tf.random.categorical(logits, 1)
        z = tf.one_hot(tf.squeeze(sample, 1), self.z_size)
        return self.decode(z)

    def encode(self, x):
        for h in self.enc_layers:
            x = h(x)
        return self.logits_layer(x)

    def decode(self, x):
        for h in self.dec_layers:
            x = h(x)
        return x

# model = VAE()
# model = CategoricalVAE()
# model = GumbelVAE()
model = StraightGradVAE()

Figure_6.png

これだけで学習できていますね。

4. VQ-VAE

VQ-VAEは特徴を離散的な埋め込み表現にマッピングすることでカテゴリカル分布を表現したVAEとなります。
今までの手法と違うのは、別途新しい離散空間を用意しそこに特徴量をマッピングするという点が異なっています。

参考
【論文解説+Tensorflowで実装】VQ-VAEを理解する | 楽しみながら理解するAI・機械学習入門

解説は…、疲れたので省略します。
参考サイトを見てください。

コードは以下です。

class VQVAE(keras.Model):
    def __init__(self):
        super().__init__()
        self.z_size = 10

        self.num_class = 10
        self.embbed = tf.Variable(
            tf.random_normal_initializer()(shape=(self.z_size, self.num_class)),
            dtype=tf.float32,
            trainable=True,
        )

        # --- encoder
        self.enc_layers = [
            kl.Flatten(),
            kl.Dense(128, activation="relu"),
        ]
        self.logits_layer = kl.Dense(self.z_size)

        # --- decoder
        self.dec_layers = [
            kl.Dense(128, activation="relu"),
            kl.Dense(28 * 28 * 1),
            kl.Reshape((28, 28, 1)),
        ]

        self.loss_tracker = tf.keras.metrics.Mean(name="loss")

    def call(self, x):
        return self.encode(x)

    def _quantized(self, x):
        # (z-e)^2 = z^2 - 2*ze + e^2
        # matmul: (batch, z) * (z, class) = (batch, class)
        d1 = tf.reduce_sum(x**2, axis=1, keepdims=True)
        d2 = 2 * tf.matmul(x, self.embbed)
        d3 = tf.reduce_sum(self.embbed**2, axis=0, keepdims=True)
        distance = d1 - d2 + d3
        encoding_indices = tf.argmin(distance, axis=1)
        q = tf.nn.embedding_lookup(tf.transpose(self.embbed, [1, 0]), encoding_indices)
        return q

    def compute_loss(self, x, y, y_pred, sample_weight):
        encoded_x = y_pred
        z = self._quantized(encoded_x)

        y_pred = self.decode(encoded_x + tf.stop_gradient(z - encoded_x))
        loss_rec = tf.reduce_mean(tf.square(y_pred - y))
        loss_e = tf.reduce_mean(tf.square(tf.stop_gradient(z) - encoded_x))
        loss_q = tf.reduce_mean(tf.square(z - tf.stop_gradient(encoded_x)))
        loss = loss_rec + loss_e + loss_q

        self.loss_tracker.update_state(loss)
        return loss

    def sample(self, x):
        logits = self.encode(x)
        z = self._quantized(logits)
        return self.decode(z)

    def encode(self, x):
        for h in self.enc_layers:
            x = h(x)
        return self.logits_layer(x)

    def decode(self, x):
        for h in self.dec_layers:
            x = h(x)
        return x

# model = VAE()
# model = CategoricalVAE()
# model = GumbelVAE()
# model = StraightGradVAE()
model = VQVAE()

lossの計算が特殊なので compute_loss 関数を実装して別途計算しています。
結果は以下です。

Figure_7.png

おわりに

性能差も見たかったので特徴量は少なめですがあまり差はないイメージですね。
Conv層もちゃんと作れば10種類に分類してその代表画像が出力されるようになるのかな?
誰かの参考になれば幸いです。

5
5
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?