8
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

FactorVAEをpixyzで実装する

Last updated at Posted at 2019-01-20

Disentangling by Factorisingの概要

βVAEはdisentangleできるモデルだが,事前分布に近づける制約を大きくしているため,再構成がぼやけてしまう問題がある.
提案手法であるFactorVAEはTotal Correlation制約をいれることによって,再構成がぼやけてしまうことなしに,disentangleさせることを可能にした.
(詳しい説明は下スライド参照).
thumbnail

実装結果

各zの次元ごとに値を動かしたときの可視化

βVAE(CCI-VAE)

z1はy座標,z4はx座標,z8は回転,z9は大小がそれぞれ対応し,それぞれの要素のみが変化している(disentangleされている)ということがわかる.
チューニングをもっと頑張れば再構成ももっとキレイに生成できたと思われる.

download-1.png
betavae_C_dsprites_z_dim10_gamma80_gif_0.gif

download-2.png
betavae_C_dsprites_z_dim10_gamma80_gif_1.gif

download.png
betavae_C_dsprites_z_dim10_gamma80_gif_2.gif

FactorVAE

In progressです.実験結果がうまくいったら載せます.

pixyzを使った実装紹介

pixyzとは深層生成モデルを簡単に書くことができるライブラリである.
準備として,BetaVAEも簡単に実装した.
可視化も含めたすべての実装はここにあげた.
簡単に実装できる上,コードも非常に少なくてすむのでpixyzは素晴らしい.

pixyz参考

βVAE

Network Architecture

モデルはVAEと全く同じである.
Encoder(正規分布)とDecoder(ベルヌーイ分布)を用意する.

from pixyz.distributions import Normal, Bernoulli, Deterministic
import torch
import torch.nn as nn
from torch.nn import functional as F

class Encoder(Normal):
    def __init__(self, z_dim=10):
        super(Encoder, self).__init__(cond_var=["x"], var=["z"])

        self.z_dim = z_dim
        # encode
        self.conv_e = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),    # 64 ⇒ 32
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 32 ⇒ 16
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 16 ⇒ 8
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        self.fc_e = nn.Sequential(
            nn.Linear(128 * 8 * 8, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 2*self.z_dim),
        )

    def forward(self, x):
        x = self.conv_e(x)
        x = x.view(-1, 128 * 8 * 8)
        x = self.fc_e(x)
        mu = x[:, :self.z_dim]
        scale = F.softplus(x[:, self.z_dim:])
        return {"loc": mu, "scale": scale}
        
        
class Decoder(Bernoulli):
    def __init__(self, z_dim=10):
        super(Decoder, self).__init__(cond_var=["z"], var=["x"])

        self.z_dim = z_dim
        # decode
        self.fc_d = nn.Sequential(
            nn.Linear(self.z_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 128 * 8 * 8),
            nn.LeakyReLU(0.2)
        )
        self.conv_d = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )
        
    def forward(self, z):
        h = self.fc_d(z)
        h = h.view(-1, 128, 8, 8)
        return {"probs": self.conv_d(h)}

Distribution

必要な確率分布を準備する.
上で作ったモデルをインスタンス化する.

z_dim=10
beta = 50

# prior model p(z)
loc = torch.tensor(0.).to(device)
scale = torch.tensor(1.).to(device)
prior = Normal(loc=loc, scale=scale, var=["z"], dim=z_dim, name="p_prior")

E = Encoder(z_dim=z_dim).to(device) # q(z|x)
D = Decoder(z_dim=z_dim).to(device) # p(x|z)

Loss

VAEと異なるのはLossの部分だけ.
βVAEの目的関数は,再構成誤差+β×制約項となる.
β=1のとき,VAEと全く同じ目的関数になることがわかる.
$$
E_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right] - \beta KL(q_\phi(z|x)||p(z))
$$

from pixyz.losses import KullbackLeibler, CrossEntropy
reconst = CrossEntropy(E, D)
kl = KullbackLeibler(E, prior)
loss_cls = reconst.mean() + beta*kl.mean()
print(loss_cls)
# mean(-E_p(z|x)[log p(x|z)]) + mean(KL[p(z|x)||p_prior(z)]) * 50

model

あとはモデルに入れて訓練するだけ.非常にコードが短い.

model = Model(loss_cls, distributions=[E, D], optimizer=optim.Adam, optimizer_params={"lr":5e-4})

for i in range(5):
    for batch_idx, x in tqdm(enumerate(train_loader)):
        x = x.to(device)
        loss = model.train({"x": x})

実は,単純なβVAEだとdisentangleさせるのが難しく,実験がうまくいかなかった.
βをめちゃくちゃ大きくすると,ある程度disentangleするようになったが,めちゃくちゃぼやけてしまった.
そのため,βVAEの改良版であるCCI-VAEを実装することにする.

CCI-VAE

βVAEのロスを工夫することで,disentangleさせることができる.
参考: Understanding disentangling in β-VAE

Loss

βを大きくしすぎると,zがxの情報をどんどん落としてしまうという問題があった.
そのため,β(ここではγ)を大きくしておく代わりに,Cをだんだん大きくしていくことによって,情報ボトルネックを解消したいという気持ちがある.

$$
E_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right] - \gamma | KL(q_\phi(z|x)||p(z)) - C |
$$

Cは訓練中に変えていくので,変数として用意する.
論文にはγの値は1000にすると書いてあったが,あまりうまくいかなかったので80にした.

    reconst = CrossEntropy(E, D)
    kl = KullbackLeibler(E, prior)
    C = Parameter("C")
    loss_cls = reconst.mean() + gamma*(kl.mean()-C).abs()

Model

Cは0から25にかけて線形にepochに比例して大きくしていくようにする.

model = Model(loss_cls, distributions=[E, D], optimizer=optim.Adam, optimizer_params={"lr":5e-4})
N = len(train_loader)

for epoch in range(epoch_num):
    for batch_idx, x in tqdm(enumerate(train_loader)):
        x = x.to(device)
        C_ = 25*(batch_idx+epoch*N)/(epoch_num*N)
        loss = model.train({"x": x, "C": C_})

最終的にこちらの結果を上にのせている.

FactorVAE

Network Architecture

FactorVAEでは,βVAEのEncoderとDecoderに加え,Discriminatorも用意する.
このDiscriminatorはzの各次元を独立にさせるために必要.

class Discriminator(Deterministic):
    def __init__(self, z_dim=10):
        super(Discriminator, self).__init__(cond_var=["z"], var=["t"], name="d")
        self.z_dim = z_dim
        self.model = nn.Sequential(
            nn.Linear(z_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, z):
        t = self.model(z)
        return {"t": t}

Distribution

確率分布を準備する.
ここで,InferenceShuffleDimは,zの各次元ごとにデータをばらばらにして出力するクラスである.

z_dim = 10
gamma = 10

# prior model p(z)
loc = torch.tensor(0.).to(device)
scale = torch.tensor(1.).to(device)
prior = Normal(loc=loc, scale=scale, var=["z"], dim=z_dim, name="p_prior")

E = Encoder(z_dim=z_dim).to(device) # q(z|x)
G = Decoder(z_dim=z_dim).to(device) # p(x|z)
D = Discriminator(z_dim=z_dim).to(device) # d(t|z)

class InferenceShuffleDim(Deterministic):
    def __init__(self):
        super(InferenceShuffleDim, self).__init__(cond_var=["x_"], var=["z"], name="q_shuffle")

    def permute_dims(self, z):
        B, _ = z.size()
        perm_z = []
        for z_j in z.split(1, 1):
            perm = torch.randperm(B).to(z.device)
            perm_z_j = z_j[perm]
            perm_z.append(perm_z_j)

        return torch.cat(perm_z, 1)        

    def forward(self, x_):
        z = E.sample({"x": x_}, return_all=False)["z"]
        return {"z": self.permute_dims(z)}


E_shuffle = InferenceShuffleDim()

Loss

FactorVAEの目的関数は,VAEの目的関数+Total Correlation(TC)である.
TCは,zのそれぞれの次元を独立にさせようという気持ちがある.

$$
E_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right] - KL(q_\phi(z|x)||p(z)) - \gamma KL(q(z)||{\bar q}(z)) \
q(z) = \int p_{data}(x)q(z|x)dx \
{\bar q}(z) := \Pi_{j=1}^d q(z_j)
$$

reconst = CrossEntropy(E, G)
kl = KullbackLeibler(E, prior)
tc = AdversarialKullbackLeibler(E, E_shuffle, discriminator=D, optimizer=optim.Adam, optimizer_params={"lr":5e-4})
loss_cls = reconst.mean() + kl.mean() + gamma*tc
print(loss_cls)
# mean(-E_p(z|x)[log p(x|z)]) + mean(KL[p(z|x)||p_prior(z)]) + mean(AdversarialKL[p(z|x)||q_shuffle(z|x_)]) * 10 

Model

discriminatorも学習させる必要があるので,いつものモデルの学習に加え,tc.train()も行っている.

model = Model(loss_cls, distributions=[E, G], optimizer=optim.Adam, optimizer_params={"lr":5e-4})

for i in range(10):
    for batch_idx, x in tqdm(enumerate(train_loader)):
        x = x.to(device)
        loss = model.train({"x": x, "x_": x})
        loss_d = tc.train({"x": x, "x_": x})

なぜか実験結果はうまくいっていないので調整中です.

8
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?