0
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

GANのDiscriminatorにハンデを与える:VGAN

Last updated at Posted at 2021-12-16

概要

Variational Discriminator Bottleneck: Improving Imitation Learning, Inverse RL, and GANs by Constraining Information Flow

GANは、DiscriminatorとGeneratorを競わせることで双方のモデルを学習させていきます。

双方が切磋琢磨に学習していく...のが理想的ですが、実際はDiscriminatorの方が先に賢くなってしまい、Generatorが作成したデータを偽物だと正しく判別しすぎるがために、Generatorがうまく学習しないらしいです。

そこで、Discriminatorにハンデを与えることで双方の実力をトントンにし、Generatorの学習を向上させるVDBという学習方法が提案されました。

VDBを用いて学習したGANをVGANと呼びます。

関連研究

上の問題を解決するために、以下のような手法が提案されています。

  • 代替の損失関数を提案(Mao et al., 2016; Zhao et al., 2016; Arjovsky et al., 2017)
  • 安定性と収束性を改善するために、勾配ペナルティーなどの正則化を組み込む(Kodali et al, 2017; Gulrajani et al., 2017a; Mescheder et al., 2018)
  • 再構築損失(Cheet al., 2016)
  • およびその他の無数のヒューリスティック(Sønderby et al., 2016; Salimans et al., 2016; Arjovsky& Bottou, 2017; Berthelot et al., 2017)

GAN

以下のブログがわかりやすかったです。

PyTorch (12) Generative Adversarial Networks (MNIST) - 人工知能に関する断創録

VGAN

$x$をDiscriminatorが判別したいデータとします。

Untitled (2).png

Discriminatorに与えるデータは$x$そのものではなく、$x$を与えられたEncoderが出力した潜在変数$z$です。

これがDiscriminatorに対するハンデとなります。

ここで用いられているDiscriminatorを以後VD(Variational Discriminator)と表記します。

Encoder

VDのボトルネックは、情報ボトルネック(Tishby & Za-slavsky, 2015)(参考)に基づいており、入力との相互情報量を最小化するために内部表現を正則化する技術である。直感的には,圧縮された表現は,元の入力に存在する無関係な"望ましくないもの"を無視することで,汎化を向上させることができます。

なんだかVAEみたい。(実際、VAEからヒントを得たらしいです)

$x$ ⇒ (Encoder) ⇒ $x'$ ⇒ (Conv) ⇒ $\mu, log\sum$ ⇒ $z$ ⇒ (Discriminator) ⇒ 0/1

VDB

Discriminatorの拡張部分

\mu = Conv_{mean}(x) \\
log\sum = Conv_{logvar}(x) \\
noize = [N(0,1),]*|\mu| \\
z = e^{0.5*log\sum}*noize + \mu
mean = self.conv_mean(x).view(-1, 128 * 7 * 7) # mean
logvar = self.conv_logvar(x).view(-1, 128 * 7 * 7) # logvar
noise = torch.randn(mean.size(), device=device)
z = (0.5 * logvar).exp() * noise + mean

VDBの目的関数

GANの目的関数:

$$\underset{G}{max}~\underset{D}{min}~~E_{x〜p^*(x)}\ [-logD(x)]+E_{x〜G(x)}[-log(1-D(G(x))]$$

Encoderを導入することで情報ボトルネックを取り入れます。

Variational Discriminator Bottleneck(VDB)の目的関数:

$$\underset{D,E}{min}~\underset{\beta\geq0}{max}~~[E_{x〜p^*(x)}\ [-logD(x)]]+E_{x〜G(x)}[E_{z〜E(z|x)}[-log(1-D(G(x))]]+\beta(E_{x〜\bar{p}(x)}[KL[E(z|x)||r(z)]]-I_c)$$

二重勾配降下法を用いて$\beta$を適応的に更新し、相互情報$I_c$の特定の制約条件を強制します。

D,E \leftarrow \underset{D,E}{arg~min}~\mathcal{L}(D,E,\beta)\\
\beta \leftarrow max(0,~\beta+\alpha_{\beta}(E_{x〜\bar{p}(x)}[KL[E(z|x)||r(z)]]-I_c)) \\ \\
\mathcal{L}(D,E,\beta) \leftarrow~[E_{x〜p^*(x)}\ [-logD(x)]]+E_{x〜G(x)}[E_{z〜E(z|x)}[-log(1-D(G(x))]]+\beta(E_{x〜\bar{p}(x)}[KL[E(z|x)||r(z)]]-I_c)

$\alpha_\beta$はステップサイズ。

ちなみに、Generatorの目的関数$E_{x〜G(x)}[E_{z〜E(z|x)}[-log(1-D(G(x))]]$はKLペナルティを除いていますが、Encoderの分布の平均$μ_{E(x)}$でDを評価することで期待値を近似することで十分らしいです(よくわからない)。

Discriminatorの損失関数

BCE\_loss = BCE(out, label) \\
normal\_D\_loss = mean(BCE\_loss) \\
KLdiv\_loss = -0.5\times mean(1+log\sum-\mu^2-e^{log\sum}) \\
KLdiv\_loss = mean(KLdiv\_loss - I_c) \\ \\
loss = normal\_D\_loss + \beta * KLdiv\_loss
I_c = 0.1
beta = 0 #1.0

def VDB_loss(out, label, mean, logvar, beta):
    normal_D_loss = torch.mean(F.binary_cross_entropy(out, label))
    
    kldiv_loss = - 0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
    kldiv_loss = kldiv_loss.mean() - I_c
    final_loss = normal_D_loss + beta * kldiv_loss
    
    return final_loss, kldiv_loss.detach()

$\beta$を更新する。

alpha = 1e-5

D_real, mean, logvar = D(real_images)
D_real_loss, loss_kldiv_real = D_criterion(D_real, y_real, mean, logvar, beta)

fake_images = G(z)
D_fake, mean, logvar = D(fake_images.detach())
D_fake_loss, loss_kldiv_fake = D_criterion(D_fake, y_fake, mean, logvar, beta)

loss_kldiv = loss_kldiv_real.item() + loss_kldiv_fake.item()
beta = max(0.0, beta + alpha * loss_kldiv)

最後の更新式を見る感じ、通常のD_lossとKLダイバージェンスlossを$\beta$で乗じた値の和がVDの損失関数になりそうです。

ソースコード(Pytorch)

vgan-mnist.ipynb

vgan-cifar10.ipynb

githubにあったコードを参考にしました。

生成画像

epoch = 100
ep100 (1).png

loss曲線

曲線 (1).png

0
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?