VAE(Variational Auto Encoder)のKL距離の理解に関して時間がかかったので自分の理解をメモしておく。
間違いがあれば指摘ください。
#VAE
VAEでは潜在変数zを連続な正規分布への変換で示す。
例えば二次元の潜在変数(z1,z2)を直前の潜在変数(μ1, μ2, σ1, σ2)を使って以下のように定義します。
$z=[z1, z2]=[μ1 + σ1 * randn(), μ2 + σ2 * randn()]$
ここでrandn()は平均値0、標準偏差1の正規分布乱数である。
こうする事によってVAEのDecoderは潜在変数z面の離散的かつ局所的な点からの変換ではなく、連続的な潜在変数z面からの変換が学習されるだろうことは何となくわかります。
ここまでは特に問題ありませんでした。
#損失関数にいきなり現れるKLダイバージェンス(KL距離)
次にVAEは損失関数にKL距離を足し学習を始めます。
ところが最初自分はこれがよく分かりませんでした。
どうして損失関数にKL距離を足す必要があるのか、そもそもKL距離って何だ(そこから?)というレベルでした。VAEを解説してくださる方の説明もKL距離に関して出てくる数式が難しく、余計にKL距離の理解を放棄する原因になりました。一度理解してみると解説で正しく書かれているんですが、残念ながら自分にはなかなか意味が分からず理解が進みませんでした。
同じ疑問を持たれた方はまずKL距離に対して調べるのがよろしいかと思います。
- KL距離はKLダイバージェンス、カルバック・ライブラー情報量ともいう。
- KL距離とはある分布P(x)とある分布Q(x)の距離二乗(のようなもの)を示す尺度である。
- 分布P(x)と分布Q(x)が一致するときKL距離は0になる。
- 逆に分布P(x)と分布Q(x)が離れるほど、(おおよそ)分布距離間の二乗に比例する形で大きくなる。
- 損失関数にKL距離を加えた時、KL距離は小さな値(ゼロ)に近づくように重みが学習される。すなわち潜在変数zの分布P(x)と対象となる分布Q(x)が一致するように収束する。
- VAEにおいては分布Q(x)は平均値0、標準偏差1の正規分布とする。
- つまりは損失関数にKL距離を加えると学習後の潜在変数zの分布P(x)(「VAEのDecoderの入力」)は平均値0、標準偏差1の正規分布に近づく形で生成される。(入力の正則化)
- 分布Q(x)を平均値0、標準偏差1の正規分布とし、潜在変数zの分布P(x)を平均値$μ$、標準偏差$σ$の正規分布とすればこの二つの分布のKL距離は$\frac{1}{2}(-2\lnσ+μ^2+σ^2-1)$となる。(計算略)
- この時、KL距離の$μ$および$σ$の偏微分の値が0になるのはそれぞれ$μ=0,σ=1$であるからKL距離はこの値で極小値を取ることが確認でき、その時のKL距離はゼロである。
- 実装では$σ$がマイナスになると$\lnσ$を求められないためか、$\ln(σ)^2$を$logvar$と置いてKL距離を$\frac{1}{2}(-logvar+μ^2+\exp(logvar)-1)$とする記述が主流である。
#所感
VAEだと潜在変数zが学習によって収束する分布を平均値0、標準偏差1の分布とします。
ところで収束させる分布は連続的であるならば別に何でもいいのではないかと個人的に思います。例えば潜在変数の分布の行き先を平均値0.5、標準偏差0.1の正規分布にしたとしても、-1から1までの一様分布にしたとしても、潜在変数z面の連続にするという目的に関せば特に問題はないわけです。
(入力の正則化という目的からしたら的外れかもしれませんけど…)
その場合、従来のVAEからどのように変更すればよいでしょうか?
結論を言うとVAEの損失関数に足すKL距離の値を変えれば対応できる事になります。
例えば仮に分布Q(x)を平均値0.5、標準偏差0.1の正規分布とするならば損失関数に足すKL距離は$\frac{1}{2}(-2\ln10σ+10^2((μ-0.5)^2+σ^2)-1)$とすれば$μ=0.5,σ=0.1$に極小値0を持つ関数になり、潜在変数zのP(x)分布の行き先を平均値0.5、標準偏差0.1の正規分布Q(x)にする事ができるかと思います。