34
31

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

VQ-VAEの理解

Posted at

#はじめに
VQ-VAEはVector Quantised(ベクトル量子化)という手法を使ったVAEです。
従来のVAEでは潜在変数zを正規分布(ガウス分布)のベクトルになるような学習を行いますが、VQ-VAEでは潜在変数を離散化した数値になるような学習を行うVAEです。モデルは(Encoder)-(量子化部分)-(Decoder)から成りますが、Encoder、Decoderについては畳み込みを行うVAEと大きく変わりません。
VQ-VAEの論文と実装をチラッと見たところ、その量子化担当分の作りの理解が二転三転したため備忘録として自分の理解をまとめます。
#Embeddingって何
VQ-VAEを語る上でおそらく避けて通れないのがEmbeddingでしょう。
自分のようにあまり理解してないとこれがどんな処理か微妙に分かりづらいです。

自分には実例を見るのが最も分かりやすかったです。
例えば以下の様にinput行列$(2,4)$で数値はindex値であり、embedding行列が$(10,3)$の場合を考えます。
この場合、input行列をonehot化させて$(2,4,10)$にしてembedding行列$(10,3)$を掛けるとembedding後$(2,4,3)$行列が作成されます。
要するにEmbeddingとは入力をonehot化してembedding行列を掛けたものに過ぎません。

>>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3)
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
tensor([[[-0.0251, -1.6902,  0.7172],
         [-0.6431,  0.0748,  0.6969],
         [ 1.4970,  1.3448, -0.9685],
         [-0.3677, -2.7265, -0.1685]],

        [[ 1.4970,  1.3448, -0.9685],
         [ 0.4362, -0.4004,  0.9400],
         [-0.6431,  0.0748,  0.6969],
         [ 0.9124, -2.3616,  1.1151]]])

#VQ-VAEの最初の理解(間違い)
image.png
最初に自分は以上の図のような理解をしました。先に言うと間違いです。
入力$z_e$を$(10,10,32)$とし、embedding行列を$(32,128)$とし、潜在空間で$128$のベクトル量子化(潜在空間を離散化)する事を考えました。
入力$z_e$にembedding行列を掛けてその中から最も1に近い場所のインデックスを取得します。(任意のonehotベクトルの内、最も距離が近いonehotベクトルのインデックス)。それが$q(z|x)$になりまして、これは$(10,10)$の行列で値はインデックス値です。
これをonehot化してembedding行列の逆行列を掛けたものを$z_q$とします。ここで$z_q$のonehot化させてembedding行列を掛けるという処理は最初に説明したembedding処理自体にほかなりません。
潜在空間の離散化が上手くできていれば出力$z_q$は入力$z_e$に近づくはずなので損失関数は$(z_e-z_q)^2$です。

図形の変化をnumpyで書けば下記のようになります。

import numpy as np

input = np.random.rand(10,10,32)
embed = np.random.rand(32,128)
embed_inv = np.linalg.pinv(embed)
dist = (np.dot(input, embed) - np.ones((10,10,128)))**2
embed_ind = np.argmin(dist, axis=2)
embed_onehot = np.identity(128)[embed_ind]
output = np.dot(embed_onehot, embed_inv)

print("input.shape=", input.shape)
print("embed.shape=", embed.shape)
print("embed_inv.shape=", embed_inv.shape)
print("dist.shape=", dist.shape)
print("embed_ind.shape=", embed_ind.shape)
print("embed_onehot.shape=", embed_onehot.shape)
print("output.shape=", output.shape)
----------------------------------------------
input.shape= (10, 10, 32)
embed.shape= (32, 128)
embed_inv.shape= (128, 32)
dist.shape= (10, 10, 128)
embed_ind.shape= (10, 10)
embed_onehot.shape= (10, 10, 128)
output.shape= (10, 10, 32)

#何が間違っているか
実際の実装と見比べれば上記の解釈は正しくありません。
理由のひとつはembedding行列の逆行列を求めるという処理がおそらく実際には不可能であることです。
このため、embedding行列の逆行列を使わない方法で$q(z|x)$および$z_q$を求める必要があります。

もうひとつは$q(z|x)$の論文の定義と異なっていることです。
論文には以下の様にあり、
image.png
上記解釈における下記の式が正しくありません。
$q(z|x)=argmin((z_e \cdot e_{mbed} - I)^2)$

ここでargminの中にembedding行列の逆行列$e_{mbed\ inv}$を掛ける事を考えると以下の様に整理できます。
$((z_e \cdot e_{mbed} - I)^2 \cdot e_{mbed\ inv}^2)=(z_e \cdot e_{mbed} \cdot e_{mbed\ inv}- I \cdot e_{mbed\ inv})^2=(z_e - e_{mbed\ inv})^2$
これは論文中にある式と等しいです。

そうするとembedding行列とその逆行列の呼び方は入れ替えたほうが以降都合がいいです。
つまり以降$e_{mbed\ inv}$を$e_{mbed}$と呼び、$e_{mbed}$を$e_{mbed\ inv}$と呼ぶことにします。

#VQ-VAEの二番目の理解
上記の訂正により、以下の理解になりました。
この時、$q(z|x)$および$z_q$を求める両方の式に$e_{mbed\ inv}$が入ってないことに注意してください。
$q(z|x)$および$z_q$はどちらも入力$z_e$と$e_{mbed}$で計算できるので、その逆行列を計算する必要性がなくなります。
具体的には$q(z|x)$は$z_e$と$e_{mbed}$から、$z_q$は$q(z|x)$と$e_{mbed}$から計算されます。

image.png

#損失関数の勾配伝播
さて、ベクトル量子化に関する損失関数は量子化前後の差分である$(z_e-z_q)^2$かと思ったら実は違います。
$sg()$という勾配を止める関数を使って$(sg(z_e)-z_q)^2+(z_e-sg(z_q))^2$のように表現されます。
これは$(z_e-z_q)^2$と何が違うのかというとおそらく$q(z|x)$から$z_e$と$e_{mbed}$の誤差逆伝播を計算するのが困難な為、思い切ってその部分の勾配伝達を切っているものと考えられます。

また、損失関数の第2項目と第3項目で更新する内容が異なります。
第2項目はembedding行列を更新しますが、入力(Encoder)へは勾配が伝達しません。
第3項目は入力(Encoder)へ勾配が伝達しますが、embedding行列を更新しません。
損失関数の第1項目に関してはDecorderから始まり$z_q→z_e$へ量子化部分をスキップしてEncoderへと伝わっていると思われます。ただ、これに関しては通常のAutoEncoderの損失と変わりません。
image.png
image.png

#argmin関数
配列の最も小さな値のインデックス値を取るargmin関数をヘヴィサイドの階段関数を使って書いてみます。

argmin(a,b) = H(b-a) \cdot 0 + H(a-b) \cdot 1 \\
argmin(a,b,c) = H(b-a) \cdot H(c-a) \cdot 0 + H(a-b) \cdot H(c-b) \cdot 1 + H(a-c) \cdot H(b-c) \cdot 2\\
H(x) =\left\{
\begin{array}{ll}
1 & (x \geq 0) \\
0 & (x \lt 0)
\end{array}
\right.

ここで最小値を引く項がすべての積の値が1になり残り、最小値でない値を引く場合はいずれかがゼロになり残りません。従って$argmin(a_{1},\cdots , a_{128})$は高次の階段関数の積になるので、たとえ勾配計算時に階段関数を連続微分可能な関数に置き換えたとしても(例えばsigmoid関数)これは微分困難でないかと思われます。
しかし、これは先ほどの説明の通り$sg()$という勾配を止める関数を使うことで考えなくて良くなります。

#まとめ
実装を流し見して理解した気でいたのだが、当初考えていたembedding行列は実際のembedding行列の逆行列であったのに今更ながらに気付きました。
勘違いしやすいかもしれませんが、embedding行列は入力をベクトル量子化する際に掛ける行列ではありません。一旦、量子化した潜在変数を潜在空間に変換する際に掛ける方がembedding行列です。
あと、ベクトル量子化に関してピクセル単位の物体認識であるSemantic Segmentationと似ている気がしました。Semantic Segmentationではsoftmaxを使って各ピクセル毎にonehot的なベクトルを生成しますが、VQでは距離二乗とargminを使って量子化ベクトルを生成します。

#参考:pytorch VQ-VAE
実際の実装例より

class Quantize(nn.Module):
    def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
        embed = torch.randn(dim, n_embed)
        ...

    def forward(self, input):
        flatten = input.reshape(-1, self.dim)
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim=True)
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = self.embed_code(embed_ind)
        ...
        diff = (quantize.detach() - input).pow(2).mean()
        quantize = input + (quantize - input).detach()

        return quantize, diff, embed_ind

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.transpose(0, 1))
34
31
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
34
31

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?