0
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

VAE

Posted at

VAEを触る機会があったので,メモ.

AEとの違い

  • オートエンコーダでは圧縮した潜在変数zの分布が分からなかった
  • VAEでは,潜在変数zを標準正規分布と仮定
    • 標準正規分布をデコーダに入力すれば画像を生成できる
    • 多様体仮説
      • 重要な部分は局所的に集中している
      • 標準正規分布の次元に押し込める

ネットワーク

image
  • 入力画像から平均ベクトルと分散ベクトルを得る.
    • ここでの平均と分散は,潜在空間zを定義するパラメータ
  • しかし,平均・分散から標準正規分布を得ても,サンプリングしないと値を得られない
  • すなわち,誤差逆伝播できない
  • よって,z = μ + εσ で近似する(Reparametrization Trick)
    • ランダムノイズεからサンプリングする
    • 確定的な値が得られ,誤差逆伝播できる
  • Reparametrization Trcik部分のコードは以下のようになる.
def reparameterize(self, mean, var):
    eps = torch.randn(mean.size().to(self.device)
    z = mean + torch.sqrt(var) * eps
    return z

損失関数

  • オートエンコーダの損失関数は再構成誤差のみだったが,VAEでは,再構成誤差に加え,潜在変数が標準正規分布に従うように正規化の誤差も考慮する.
  • 損失関数の導出にはKLダイバージェンス等を使う.(説明略)
  • 正則化の損失関数:image
    • 変分下界の最大化と同じ.
    • これは,μ=1, σ=0の時に最小化する
    • すなわち,エンコーダが出力する潜在変数zが標準正規分布と一致すると損失が0になる.
    • 潜在変数zを標準正規分布に押し込める
  • 損失関数をコードにすると以下.
# 再構成誤差
reconst = -torch.mean(torch.sum(x*torch.log(y)+(1-x)*torch.log(1-y), dim=1))

# 正則化
kl = - 1/2 * torch.mean(torch.sum(1+torch.log(var)-mean**2-var, dim=1))

# 損失関数
L = reconst + kl

結果

  • ぼやける
    • 潜在変数の分布を標準正規分布と仮定
    • ピクセル単位での損失計算

次は,確率分布を明示的にモデル化するのではなく,NNで確率分布を最適化するGANを見ていく.

0
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?