離散空間の勾配を計算したい状況があったのでまとめてみました。
タイトルの通り3種類の方法をとり上げますが、あまり情報がなく他にもあるかもしれません…。
1. ReparameterizationTrick(VAE)
まずは連続空間の ReparameterizationTrick を見ていきます。
これは、確率分布を微分可能な関数に置き換えるテクニックで、VAE(Variational Autoencoder)でよく登場するテクニックです。
VAEを簡単に言うと、教師なし学習における特徴抽出の一種で、特徴を標準正規分布上で表現できるように特徴抽出します。
特徴は正規分布に従うので連続空間となります。
これを愚直に表現すると以下の問題が発生します。
これを解決する手法が ReparameterizationTrick で、正規分布上でも勾配を流すことができるテクニックです。
MNISTによるサンプルコード
参考:TensorFlow > 学ぶ > TensorFlow Core > チュートリアル > 畳み込み変分オートエンコーダ
MNISTで実際に実装してみます。
また以降はこのコードをベースにモデルのみを変更して同じコードを使いまわしていきます。
VAEを正確に実装することが目的ではないので以下の違いがある点は注意してください。
- Conv2D層はありません。Dense層が1層のみで作成しています。
- 正則化項に該当するKL損失は省略しています。(なのでこれは標準正規分布には従いません)
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.datasets import mnist
kl = keras.layers
# データの読み込みと前処理
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images = train_images.reshape((60000, 28, 28, 1)).astype("float32") / 255
test_images = test_images.reshape((10000, 28, 28, 1)).astype("float32") / 255
class VAE(keras.Model):
def __init__(self):
super().__init__()
self.z_size = 10
# --- encoder
self.enc_layers = [
kl.Flatten(),
kl.Dense(128, activation="relu"),
]
self.mean_layer = kl.Dense(self.z_size)
self.log_stddev_layer = kl.Dense(self.z_size)
# --- decoder
self.dec_layers = [
kl.Dense(128, activation="relu"),
kl.Dense(28 * 28 * 1),
kl.Reshape((28, 28, 1)),
]
def call(self, x):
mean, stddev = self.encode(x)
# tf.random.normalはshapeがNoneだとエラーになる
if mean.shape[0] is None:
z = mean + stddev * 0
else:
# --- reparameterization trick
e = tf.random.normal(shape=mean.shape)
z = mean + stddev * e
return self.decode(z)
def sample(self, x):
# 乱数を使わない場合は平均をそのまま使う
mean, stddev = self.encode(x)
return self.decode(mean)
def encode(self, x):
for h in self.enc_layers:
x = h(x)
mean = self.mean_layer(x)
log_stddev = self.log_stddev_layer(x)
stddev = tf.math.exp(log_stddev)
return mean, stddev
def decode(self, x):
for h in self.dec_layers:
x = h(x)
return x
model = VAE()
model.compile(optimizer="adam", loss="mse")
# モデルの訓練
model.fit(train_images, train_images, epochs=10, batch_size=64)
test_loss = model.evaluate(test_images, test_images)
print("Test accuracy:", test_loss) # Test accuracy: 0.020644141361117363
# 表示
pred_images = model.sample(test_images[:8])
fig = plt.figure(figsize=(4, 4))
for i in range(8):
plt.subplot(4, 4, i + 1)
plt.imshow(test_images[i, :, :, 0], cmap="gray")
plt.axis("off")
for i in range(8):
plt.subplot(4, 4, 8 + i + 1)
plt.imshow(pred_images[i, :, :, 0], cmap="gray")
plt.axis("off")
plt.show()
結果は以下です。
上2段が入力で下2段が出力結果です。
学習は出来ていそうですね。
2. 学習できないCategoricalVAE
VAEでは特徴を正規分布(連続空間)と仮定しましたが、カテゴリカル分布(離散空間)と仮定して作成します。
カテゴリカル分布なので例えば特徴数を10とすれば0~9の値をとります。
どの値を取るかはsoftmaxで確率的に表現し、出力側では確率で決まった値をonehot化して渡します。
コードは以下です。
class CategoricalVAE(keras.Model):
def __init__(self):
super().__init__()
self.z_size = 10
# --- encoder
self.enc_layers = [
kl.Flatten(),
kl.Dense(128, activation="relu"),
]
self.logits_layer = kl.Dense(self.z_size)
# --- decoder
self.dec_layers = [
kl.Dense(128, activation="relu"),
kl.Dense(28 * 28 * 1),
kl.Reshape((28, 28, 1)),
]
def call(self, x):
logits = self.encode(x)
sample = tf.random.categorical(logits, 1)
z = tf.one_hot(tf.squeeze(sample, 1), self.z_size)
return self.decode(z)
def sample(self, x):
return self.call(x)
def encode(self, x):
for h in self.enc_layers:
x = h(x)
return self.logits_layer(x)
def decode(self, x):
for h in self.dec_layers:
x = h(x)
return x
# --- modelを変更するだけで実行できます。
#model = VAE()
model = CategoricalVAE()
学習は出来ませんが、エラーなく実行することは出来ました。
学習結果は以下です。
勾配が流れないので学習できていませんね。
ちなみに以下警告も出力され、いくつかの変数で勾配が流れていないことを指摘されます。
WARNING:tensorflow:Gradients do not exist for variables ['categorical_vae/dense/kernel:0', 'categorical_vae/dense/bias:0', 'categorical_vae/dense_1/kernel:0', 'categorical_vae/dense_1/bias:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument?
3. Gumbel-Max Trick
離散的なReparameterizationTrickを検索したら最初に見つけた手法です。
ざっくりいうとGumbel-Softmaxを使うことでReparameterizationTrickを行う手法です。
参考
・Gumbel-Max Trick(ガンベル最大トリック)を理解する | 楽しみながら理解するAI・機械学習入門
・Categorical Reparameterization with Gumbel-Softmax | ご注文は機械学習ですか?
Gumbel-Softmax分布
SoftmaxとGumbel-Softmaxは以下です。
\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{K} e^{x_j}}
\text{gumbel_softmax}(x, \tau)_i = \frac{e^{(x_i + g_i)/\tau}}{\sum_{j=1}^{K} e^{(x_j + g_j)/\tau}}
$g$ はGumbel関数からのサンプルを表し、$\tau$ は乱数を制御する温度パラメータです。
このGumbel分布ですが、各要素において以下の$z$が最大値をとる確率はsoftmaxの確率と一致します。
$$
z_i = x_i + g_i
$$
$g$ はGumbel分布に従う乱数で、一様乱数をGumbel分布の逆関数に入れることで取得できます。
\begin{align}
u \sim Uniform(0,1) \\
g = -ln(-ln(u))
\end{align}
実際に見てみます。
import matplotlib.pyplot as plt
import numpy as np
def softmax(x):
exp_x = np.exp(x - np.max(x))
return exp_x / np.sum(exp_x, axis=0)
def gumbel_inverse(x):
return -np.log(-np.log(x))
x = np.array([1.2, 2.3, 1.4, 3.3, 0.1])
num_samples = 100000
# --- random softmax
y_softmax = np.random.choice(len(x), size=num_samples, p=softmax(x))
# --- random gumbel
y_gumbel = []
for _ in range(num_samples):
rnd = np.random.uniform(size=(len(x)))
z = x + gumbel_inverse(rnd)
y_gumbel.append(np.argmax(z))
# --- plot
plt.hist([y_softmax, y_gumbel], label=["softmax", "gumbel"])
plt.legend()
plt.show()
見事に一致していますね。
最後にサンプリングと勾配で使う確率ベクトルの式を書いておきます。
・サンプリング
$$
z = \text{onehot}(\underset{i}{\text{argmax}}(ln(x_i) + g_i))
$$
・勾配で使う確率ベクトル
$$
z = \text{softmax}(\frac{ln(x_i) + g_i}{\tau})
$$
コード
実際に学習してみます。
class GumbelVAE(keras.Model):
def __init__(self):
super().__init__()
self.z_size = 10
self.temperature = 1
# --- encoder
self.enc_layers = [
kl.Flatten(),
kl.Dense(128, activation="relu"),
]
self.logits_layer = kl.Dense(self.z_size)
# --- decoder
self.dec_layers = [
kl.Dense(128, activation="relu"),
kl.Dense(28 * 28 * 1),
kl.Reshape((28, 28, 1)),
]
def gumbel_inverse(self, x):
return -tf.math.log(-tf.math.log(x))
def call(self, x):
logits = self.encode(x)
# --- Gumbel-Max trick
rnd = tf.random.uniform(tf.shape(logits), minval=1e-10, maxval=1.0)
z = tf.nn.softmax((logits + self.gumbel_inverse(rnd)) / self.temperature)
return self.decode(z)
def sample(self, x):
logits = self.encode(x)
rnd = tf.random.uniform(tf.shape(logits), minval=1e-10, maxval=1.0)
logits = logits + self.gumbel_inverse(rnd)
# 最大値とsoftmaxの確率が同じになる
z = tf.argmax(logits, axis=-1)
z = tf.one_hot(z, self.z_size)
return self.decode(z)
def encode(self, x):
for h in self.enc_layers:
x = h(x)
return self.logits_layer(x)
def decode(self, x):
for h in self.dec_layers:
x = h(x)
return x
# model = VAE()
# model = CategoricalVAE()
model = GumbelVAE()
ちゃんと学習できていますね。
3. Straight-Through Gradients with Automatic Differentiation
名前が分からなかったので論文からそのまま、タイトルは長いので削った形です。
DreamerV2の論文に記載がある手法で、アルゴリズムは以下。
サンプリングには勾配を流さず直接確率を計算する部分だけ流すというかなり直接的な方法ですね。
今回やりたい事をやるには一番簡単かも…。
参考
・Mastering Atari with Discrete World Models(論文)
・Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation(論文の参照元の論文)
コードは以下です。
class StraightGradVAE(keras.Model):
def __init__(self):
super().__init__()
self.z_size = 10
# --- encoder
self.enc_layers = [
kl.Flatten(),
kl.Dense(128, activation="relu"),
]
self.logits_layer = kl.Dense(self.z_size)
# --- decoder
self.dec_layers = [
kl.Dense(128, activation="relu"),
kl.Dense(28 * 28 * 1),
kl.Reshape((28, 28, 1)),
]
def call(self, x):
logits = self.encode(x)
# --- Straight-Through Gradients with Automatic Differentiation
sample = tf.random.categorical(logits, 1)
sample = tf.one_hot(tf.squeeze(sample, 1), self.z_size)
probs = tf.nn.softmax(logits)
z = sample + probs - tf.stop_gradient(probs)
return self.decode(z)
def sample(self, x):
logits = self.encode(x)
sample = tf.random.categorical(logits, 1)
z = tf.one_hot(tf.squeeze(sample, 1), self.z_size)
return self.decode(z)
def encode(self, x):
for h in self.enc_layers:
x = h(x)
return self.logits_layer(x)
def decode(self, x):
for h in self.dec_layers:
x = h(x)
return x
# model = VAE()
# model = CategoricalVAE()
# model = GumbelVAE()
model = StraightGradVAE()
これだけで学習できていますね。
4. VQ-VAE
VQ-VAEは特徴を離散的な埋め込み表現にマッピングすることでカテゴリカル分布を表現したVAEとなります。
今までの手法と違うのは、別途新しい離散空間を用意しそこに特徴量をマッピングするという点が異なっています。
参考
・【論文解説+Tensorflowで実装】VQ-VAEを理解する | 楽しみながら理解するAI・機械学習入門
解説は…、疲れたので省略します。
参考サイトを見てください。
コードは以下です。
class VQVAE(keras.Model):
def __init__(self):
super().__init__()
self.z_size = 10
self.num_class = 10
self.embbed = tf.Variable(
tf.random_normal_initializer()(shape=(self.z_size, self.num_class)),
dtype=tf.float32,
trainable=True,
)
# --- encoder
self.enc_layers = [
kl.Flatten(),
kl.Dense(128, activation="relu"),
]
self.logits_layer = kl.Dense(self.z_size)
# --- decoder
self.dec_layers = [
kl.Dense(128, activation="relu"),
kl.Dense(28 * 28 * 1),
kl.Reshape((28, 28, 1)),
]
self.loss_tracker = tf.keras.metrics.Mean(name="loss")
def call(self, x):
return self.encode(x)
def _quantized(self, x):
# (z-e)^2 = z^2 - 2*ze + e^2
# matmul: (batch, z) * (z, class) = (batch, class)
d1 = tf.reduce_sum(x**2, axis=1, keepdims=True)
d2 = 2 * tf.matmul(x, self.embbed)
d3 = tf.reduce_sum(self.embbed**2, axis=0, keepdims=True)
distance = d1 - d2 + d3
encoding_indices = tf.argmin(distance, axis=1)
q = tf.nn.embedding_lookup(tf.transpose(self.embbed, [1, 0]), encoding_indices)
return q
def compute_loss(self, x, y, y_pred, sample_weight):
encoded_x = y_pred
z = self._quantized(encoded_x)
y_pred = self.decode(encoded_x + tf.stop_gradient(z - encoded_x))
loss_rec = tf.reduce_mean(tf.square(y_pred - y))
loss_e = tf.reduce_mean(tf.square(tf.stop_gradient(z) - encoded_x))
loss_q = tf.reduce_mean(tf.square(z - tf.stop_gradient(encoded_x)))
loss = loss_rec + loss_e + loss_q
self.loss_tracker.update_state(loss)
return loss
def sample(self, x):
logits = self.encode(x)
z = self._quantized(logits)
return self.decode(z)
def encode(self, x):
for h in self.enc_layers:
x = h(x)
return self.logits_layer(x)
def decode(self, x):
for h in self.dec_layers:
x = h(x)
return x
# model = VAE()
# model = CategoricalVAE()
# model = GumbelVAE()
# model = StraightGradVAE()
model = VQVAE()
lossの計算が特殊なので compute_loss 関数を実装して別途計算しています。
結果は以下です。
おわりに
性能差も見たかったので特徴量は少なめですがあまり差はないイメージですね。
Conv層もちゃんと作れば10種類に分類してその代表画像が出力されるようになるのかな?
誰かの参考になれば幸いです。