VAEを触る機会があったので,メモ.
AEとの違い
- オートエンコーダでは圧縮した潜在変数zの分布が分からなかった
- VAEでは,潜在変数zを標準正規分布と仮定
- 標準正規分布をデコーダに入力すれば画像を生成できる
- 多様体仮説
- 重要な部分は局所的に集中している
- 標準正規分布の次元に押し込める
ネットワーク
- 入力画像から平均ベクトルと分散ベクトルを得る.
- ここでの平均と分散は,潜在空間zを定義するパラメータ
- しかし,平均・分散から標準正規分布を得ても,サンプリングしないと値を得られない
- すなわち,誤差逆伝播できない
- よって,z = μ + εσ で近似する(Reparametrization Trick)
- ランダムノイズεからサンプリングする
- 確定的な値が得られ,誤差逆伝播できる
- Reparametrization Trcik部分のコードは以下のようになる.
def reparameterize(self, mean, var):
eps = torch.randn(mean.size().to(self.device)
z = mean + torch.sqrt(var) * eps
return z
損失関数
- オートエンコーダの損失関数は再構成誤差のみだったが,VAEでは,再構成誤差に加え,潜在変数が標準正規分布に従うように正規化の誤差も考慮する.
- 損失関数の導出にはKLダイバージェンス等を使う.(説明略)
- 正則化の損失関数:
- 変分下界の最大化と同じ.
- これは,μ=1, σ=0の時に最小化する
- すなわち,エンコーダが出力する潜在変数zが標準正規分布と一致すると損失が0になる.
- 潜在変数zを標準正規分布に押し込める
- 損失関数をコードにすると以下.
# 再構成誤差
reconst = -torch.mean(torch.sum(x*torch.log(y)+(1-x)*torch.log(1-y), dim=1))
# 正則化
kl = - 1/2 * torch.mean(torch.sum(1+torch.log(var)-mean**2-var, dim=1))
# 損失関数
L = reconst + kl
結果
- ぼやける
- 潜在変数の分布を標準正規分布と仮定
- ピクセル単位での損失計算
次は,確率分布を明示的にモデル化するのではなく,NNで確率分布を最適化するGANを見ていく.