概要
GANは、DiscriminatorとGeneratorを競わせることで双方のモデルを学習させていきます。
双方が切磋琢磨に学習していく...のが理想的ですが、実際はDiscriminatorの方が先に賢くなってしまい、Generatorが作成したデータを偽物だと正しく判別しすぎるがために、Generatorがうまく学習しないらしいです。
そこで、Discriminatorにハンデを与えることで双方の実力をトントンにし、Generatorの学習を向上させるVDBという学習方法が提案されました。
VDBを用いて学習したGANをVGANと呼びます。
関連研究
上の問題を解決するために、以下のような手法が提案されています。
- 代替の損失関数を提案(Mao et al., 2016; Zhao et al., 2016; Arjovsky et al., 2017)
- 安定性と収束性を改善するために、勾配ペナルティーなどの正則化を組み込む(Kodali et al, 2017; Gulrajani et al., 2017a; Mescheder et al., 2018)
- 再構築損失(Cheet al., 2016)
- およびその他の無数のヒューリスティック(Sønderby et al., 2016; Salimans et al., 2016; Arjovsky& Bottou, 2017; Berthelot et al., 2017)
GAN
以下のブログがわかりやすかったです。
PyTorch (12) Generative Adversarial Networks (MNIST) - 人工知能に関する断創録
VGAN
$x$をDiscriminatorが判別したいデータとします。
Discriminatorに与えるデータは$x$そのものではなく、$x$を与えられたEncoderが出力した潜在変数$z$です。
これがDiscriminatorに対するハンデとなります。
ここで用いられているDiscriminatorを以後VD(Variational Discriminator)と表記します。
Encoder
VDのボトルネックは、情報ボトルネック(Tishby & Za-slavsky, 2015)(参考)に基づいており、入力との相互情報量を最小化するために内部表現を正則化する技術である。直感的には,圧縮された表現は,元の入力に存在する無関係な"望ましくないもの"を無視することで,汎化を向上させることができます。
なんだかVAEみたい。(実際、VAEからヒントを得たらしいです)
$x$ ⇒ (Encoder) ⇒ $x'$ ⇒ (Conv) ⇒ $\mu, log\sum$ ⇒ $z$ ⇒ (Discriminator) ⇒ 0/1
VDB
Discriminatorの拡張部分
\mu = Conv_{mean}(x) \\
log\sum = Conv_{logvar}(x) \\
noize = [N(0,1),]*|\mu| \\
z = e^{0.5*log\sum}*noize + \mu
mean = self.conv_mean(x).view(-1, 128 * 7 * 7) # mean
logvar = self.conv_logvar(x).view(-1, 128 * 7 * 7) # logvar
noise = torch.randn(mean.size(), device=device)
z = (0.5 * logvar).exp() * noise + mean
VDBの目的関数
GANの目的関数:
$$\underset{G}{max}~\underset{D}{min}~~E_{x〜p^*(x)}\ [-logD(x)]+E_{x〜G(x)}[-log(1-D(G(x))]$$
Encoderを導入することで情報ボトルネックを取り入れます。
Variational Discriminator Bottleneck(VDB)の目的関数:
$$\underset{D,E}{min}~\underset{\beta\geq0}{max}~~[E_{x〜p^*(x)}\ [-logD(x)]]+E_{x〜G(x)}[E_{z〜E(z|x)}[-log(1-D(G(x))]]+\beta(E_{x〜\bar{p}(x)}[KL[E(z|x)||r(z)]]-I_c)$$
二重勾配降下法を用いて$\beta$を適応的に更新し、相互情報$I_c$の特定の制約条件を強制します。
D,E \leftarrow \underset{D,E}{arg~min}~\mathcal{L}(D,E,\beta)\\
\beta \leftarrow max(0,~\beta+\alpha_{\beta}(E_{x〜\bar{p}(x)}[KL[E(z|x)||r(z)]]-I_c)) \\ \\
\mathcal{L}(D,E,\beta) \leftarrow~[E_{x〜p^*(x)}\ [-logD(x)]]+E_{x〜G(x)}[E_{z〜E(z|x)}[-log(1-D(G(x))]]+\beta(E_{x〜\bar{p}(x)}[KL[E(z|x)||r(z)]]-I_c)
$\alpha_\beta$はステップサイズ。
ちなみに、Generatorの目的関数$E_{x〜G(x)}[E_{z〜E(z|x)}[-log(1-D(G(x))]]$はKLペナルティを除いていますが、Encoderの分布の平均$μ_{E(x)}$でDを評価することで期待値を近似することで十分らしいです(よくわからない)。
Discriminatorの損失関数
BCE\_loss = BCE(out, label) \\
normal\_D\_loss = mean(BCE\_loss) \\
KLdiv\_loss = -0.5\times mean(1+log\sum-\mu^2-e^{log\sum}) \\
KLdiv\_loss = mean(KLdiv\_loss - I_c) \\ \\
loss = normal\_D\_loss + \beta * KLdiv\_loss
I_c = 0.1
beta = 0 #1.0
def VDB_loss(out, label, mean, logvar, beta):
normal_D_loss = torch.mean(F.binary_cross_entropy(out, label))
kldiv_loss = - 0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
kldiv_loss = kldiv_loss.mean() - I_c
final_loss = normal_D_loss + beta * kldiv_loss
return final_loss, kldiv_loss.detach()
$\beta$を更新する。
alpha = 1e-5
D_real, mean, logvar = D(real_images)
D_real_loss, loss_kldiv_real = D_criterion(D_real, y_real, mean, logvar, beta)
fake_images = G(z)
D_fake, mean, logvar = D(fake_images.detach())
D_fake_loss, loss_kldiv_fake = D_criterion(D_fake, y_fake, mean, logvar, beta)
loss_kldiv = loss_kldiv_real.item() + loss_kldiv_fake.item()
beta = max(0.0, beta + alpha * loss_kldiv)
最後の更新式を見る感じ、通常のD_lossとKLダイバージェンスlossを$\beta$で乗じた値の和がVDの損失関数になりそうです。
ソースコード(Pytorch)
githubにあったコードを参考にしました。