LoginSignup
3
3

More than 3 years have passed since last update.

RealnessGANを実装してみた

Posted at

RealnessGANという新しいGANがあるようですが、日本語の情報がほとんどなさそうなので。
実装してCIFAR-10について学習させました。

論文はReal or Not Real, that is the Questionです。
論文の著者による実装も参考にしました。
ここの実装も参考になります。

RealnessGANを使うとDCGANでもきれいに学習できるらしいです。

CelebAデータセットの学習結果
https://github.com/kam1107/RealnessGAN/blob/master/images/CelebA_snapshot.png

概要

普通のGANではDiscriminator(識別器)の出力は「Realness(現実っぽさ)」を表すスカラー値です。
この論文では、Realnessの確率分布を出力するDiscriminatorを使うことが提案されています。

Discriminatorが出力する情報が増えることで、Generator(生成器)の学習もよりうまくいくらしいです。
論文によると、普通のDCGANの構造でも1024×1024の顔画像(FFHQデータセット)の学習に成功したとか。
FFHQデータセットの学習結果

記号の意味

  • $D$ Discriminator(識別器)
  • $G$ Generator(生成器)
  • $\boldsymbol{z}$ Generatorに入れるノイズ(潜在表現)
  • $\mathcal{A}_0$ 偽物の画像に対するAnchor($D$の正解として与えるRealnessの分布?)
  • $\mathcal{A}_1$ 本物の画像に対するAnchor
  • $p_{\mathrm{data}}(\boldsymbol{x})$ データセットからランダムに選んで画像$\boldsymbol{x}$が出てくる確率?
  • $p_g(\boldsymbol{x})$ ランダムに選んだ$\boldsymbol{z}$から$G(\boldsymbol{z})$が画像$\boldsymbol{x}$になる確率?

手法

普通のGANのDiscriminatorが出力するのは連続なスカラー値「Realness」です。
一方、RealnessGANのDiscriminatorが出力するのはRealnessの離散確率分布であるようです。
例えば

D(\mbox{画像}) = 
\begin{bmatrix}
\mbox{画像のRealnessが }1.0\mbox{ である確率} \\
\mbox{画像のRealnessが }0.9\mbox{ である確率} \\
\vdots \\
\mbox{画像のRealnessが }-0.9\mbox{ である確率} \\
\mbox{画像のRealnessが }-1.0\mbox{ である確率} \\
\end{bmatrix}

みたいになるようです。
この離散化されたRealnessの値を論文ではOutcomeと呼んでいるようです。
確率分布はDiscriminatorの生の出力について、チャンネル方向のソフトマックスを取ることで求められるようです。

また、Realnessの確率分布についての正解データを論文ではAnchorと呼んでいるみたいです。
例えば

\mathcal{A}_0 = 
\begin{bmatrix}
\mbox{偽物の画像のRealnessが }1.0\mbox{ である確率} \\
\mbox{偽物の画像のRealnessが }0.9\mbox{ である確率} \\
\vdots \\
\mbox{偽物の画像のRealnessが }-0.9\mbox{ である確率} \\
\mbox{偽物の画像のRealnessが }-1.0\mbox{ である確率} \\
\end{bmatrix}
\mathcal{A}_1 = 
\begin{bmatrix}
\mbox{本物の画像のRealnessが }1.0\mbox{ である確率} \\
\mbox{本物の画像のRealnessが }0.9\mbox{ である確率} \\
\vdots \\
\mbox{本物の画像のRealnessが }-0.9\mbox{ である確率} \\
\mbox{本物の画像のRealnessが }-1.0\mbox{ である確率} \\
\end{bmatrix}

Realnessの値域、Anchorの分布などは自由にカスタムできるようです。

目的関数

論文によると目的関数は

\max_{G} \min_{D} V(G, D) =
\mathbb{E}_{\boldsymbol{x} \sim p_{\mathrm{data}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{1} || D(\boldsymbol{x}) )] + 
\mathbb{E}_{\boldsymbol{x} \sim p_{g}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{0} || D(\boldsymbol{x}) )].
\tag{3}

らしいです。
そこからGenerator $G$の目的関数を取り出すと、

(G_{\mathrm{objective1}}) \quad
\min_{G} 
- \mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{0} || D(G(\boldsymbol{z}))].
\tag{18}

となるらしいのですが、これだと学習がうまくいかないらしいです。
そこで論文では$G$について、二つの目的関数が提案されています。

(G_{\mathrm{objective2}}) \quad
\min_{G} \quad
\mathbb{E}_{\boldsymbol{x} \sim p_{\mathrm{data}}, \boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( D(\boldsymbol{x}) || D(G(\boldsymbol{z}))]
- \mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{0} || D(G(\boldsymbol{z}))],
\tag{19}
(G_{\mathrm{objective3}}) \quad
\min_{G} \quad
\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{1} || D(G(\boldsymbol{z}))]
- \mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{0} || D(G(\boldsymbol{z}))].
\tag{20}

実験をしてみたところ、この$G$についての三つの目的関数のうち、式(19)の$G_{\mathrm{objective2}}$が一番よかったらしいです。

まとめると、

\begin{align}
\min_{D} & \quad
\mathbb{E}_{\boldsymbol{x} \sim p_{\mathrm{data}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{1} || D(\boldsymbol{x}))] +  
\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{0} || D(G(\boldsymbol{z}) ))] \\
\min_{G} & \quad
\mathbb{E}_{\boldsymbol{x} \sim p_{\mathrm{data}}, \boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( D(\boldsymbol{x}) || D(G(\boldsymbol{z})))] -
\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}}[\mathcal{D}_{\mathrm{KL}}( \mathcal{A}_{0} || D(G(\boldsymbol{z})))]
\end{align}

となります。
$\mathbb{E}_ {\boldsymbol{x} \sim p_{\mathrm{data}}}[\cdots]$、$\mathbb{E}_ {\boldsymbol{z} \sim p_{\boldsymbol{z}}}[\cdots]$、$\mathbb{E}_ {\boldsymbol{x} \sim p_{\mathrm{data}}, \boldsymbol{z} \sim p_{\boldsymbol{z}}}[\cdots]$の部分はミニバッチの平均を取ればいい?

論文によるとAnchorを$\mathcal{A} _0 = [1, 0]$、$\mathcal{A} _1 = [0, 1]$とすると、目的関数が普通のGANと同じ形になるからRealnessGANは普通のGANを一般化したものであると考えられるらしいです。

雑多な情報

論文にはいくつかの議論と学習の工夫が載っているので、それらをまとめました。

Outcomeの数

Outcome(Discriminatorの出力の次元)を増やすほどいいらしいです。
Outcomeを増やした場合、Generator $G$を更新する回数を増やすといいらしい?

Anchorの選択

偽物の画像のAnchor$\mathcal{A} _0$と本物の画像のAnchor$\mathcal{A} _1$とのKLダイバージェンスが大きいほどいいらしいです。

特徴リサンプリング

Discriminatorの出力次元を2倍して、平均と標準偏差として正規分布からサンプリングすると性能が上がるらしいです。
Githubのソースだと標準偏差はそのまま使わず、$2$で割ってから指数をとっているみたいです(つまりもとの出力は分散の対数)。
特に学習の後半で学習が安定するようです。
下のコードではやっていません。

コード

CIFAR-10について学習させます。

realness_gan.py
import numpy
import torch
import torchvision

# KLダイバージェンスを計算する関数
# epsilonはlogでNaNが出ないように入れる
def kl_divergence(p, q, epsilon=1e-16):
    return torch.mean(torch.sum(p * torch.log((p + epsilon) / (q + epsilon)), dim=1))

# torch.nn.Sequentialにreshapeを入れられるように
class Reshape(torch.nn.Module):
    def __init__(self, *shape):
        super(Reshape, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.reshape(*self.shape)

class GAN:
    def __init__(self):
        self.noise_dimension = 100
        self.n_outcomes      = 20
        self.device          = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        self.discriminator = torch.nn.Sequential(
            torch.nn.Conv2d( 3, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(2),
            torch.nn.Conv2d(32, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(2),
            torch.nn.Conv2d(32, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(2),
            torch.nn.Conv2d(32, 32, 3, padding=1),
            torch.nn.ReLU(),
            Reshape(-1, 32 * 4 * 4),
            torch.nn.Linear(32 * 4 * 4, self.n_outcomes),
        ).to(self.device)
        self.generator = torch.nn.Sequential(
            torch.nn.Linear(self.noise_dimension, 32 * 4 * 4),
            Reshape(-1, 32, 4, 4),
            torch.nn.Conv2d(32, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Upsample(scale_factor=2),
            torch.nn.Conv2d(32, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Upsample(scale_factor=2),
            torch.nn.Conv2d(32, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Upsample(scale_factor=2),
            torch.nn.Conv2d(32,  3, 3, padding=1),
            torch.nn.Sigmoid(),
        ).to(self.device)

        self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(),
                                                        lr=0.0001,
                                                        betas=[0.0, 0.9])
        self.generator_optimizer     = torch.optim.Adam(self.generator.parameters(),
                                                        lr=0.0001,
                                                        betas=[0.0, 0.9])

        # ここでAnchorを計算する
        # Githubにある著者の実装にならって乱数のヒストグラムを取る
        normal = numpy.random.normal(1, 1, 1000) # 平均+1、標準偏差1の正規分布
        count, _ = numpy.histogram(normal, self.n_outcomes, (-2, 2)) # -2から+2までのヒストグラムを取る
        self.real_anchor = count / sum(count) # 合計が1になるように正規化

        normal = numpy.random.normal(-1, 1, 1000) # 平均-1、標準偏差1の正規分布
        count, _ = numpy.histogram(normal, self.n_outcomes, (-2, 2))
        self.fake_anchor = count / sum(count)

        self.real_anchor = torch.Tensor(self.real_anchor).to(self.device)
        self.fake_anchor = torch.Tensor(self.fake_anchor).to(self.device)

    def generate_fakes(self, num):
        mean = torch.zeros(num, self.noise_dimension, device=self.device)
        std  = torch.ones(num, self.noise_dimension, device=self.device)
        noise = torch.normal(mean, std)
        return self.generator(noise)

    def train_discriminator(self, real):
        batch_size = real.shape[0]
        fake = self.generate_fakes(batch_size).detach()

        # Discriminatorの出力についてソフトマックスをとって確率にする
        real_feature = torch.nn.functional.softmax(self.discriminator(real), dim=1)
        fake_feature = torch.nn.functional.softmax(self.discriminator(fake), dim=1)

        loss = kl_divergence(self.real_anchor, real_feature) + kl_divergence(self.fake_anchor, fake_feature) # 論文の式(3)

        self.discriminator_optimizer.zero_grad()
        loss.backward()
        self.discriminator_optimizer.step()

        return float(loss)

    def train_generator(self, real):
        batch_size = real.shape[0]
        fake = self.generate_fakes(batch_size)

        real_feature = torch.nn.functional.softmax(self.discriminator(real), dim=1)
        fake_feature = torch.nn.functional.softmax(self.discriminator(fake), dim=1)

        # loss = -kl_divergence(self.fake_anchor, fake_feature) # 論文の式(18)
        loss = kl_divergence(real_feature, fake_feature) - kl_divergence(self.fake_anchor, fake_feature) # 論文の式(19)
        # loss = kl_divergence(self.real_anchor, fake_feature) - kl_divergence(self.fake_anchor, fake_feature) # 論文の式(20)

        self.generator_optimizer.zero_grad()
        loss.backward()
        self.generator_optimizer.step()

        return float(loss)

    def step(self, real):
        real = real.to(self.device)

        discriminator_loss = self.train_discriminator(real)
        generator_loss     = self.train_generator(real)

        return discriminator_loss, generator_loss

if __name__ == '__main__':
    transformer = torchvision.transforms.Compose([
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
    ])

    dataset = torchvision.datasets.CIFAR10(root='C:/datasets',
                                           transform=transformer,
                                           download=True)

    iterator = torch.utils.data.DataLoader(dataset,
                                           batch_size=128,
                                           drop_last=True)

    gan = GAN()
    n_steps = 0

    for epoch in range(1000):
        for iteration, data in enumerate(iterator):
            real = data[0].float()
            discriminator_loss, generator_loss = gan.step(real)

            print('epoch : {}, iteration : {}, discriminator_loss : {}, generator_loss : {}'.format(
                epoch, iteration, discriminator_loss, generator_loss
            ))

            n_steps += 1

            if iteration == 0:
                fakes = gan.generate_fakes(64)
                torchvision.utils.save_image(fakes, 'out/{}.png'.format(n_steps))

結果

0エポック目(1ステップ目)
1.png

10エポック目(3901ステップ目)
3901.png

100エポック目(39001ステップ目)
39001.png

500エポック目(195001ステップ目)
195001.png

この実装だとBatch NormalizationSpectral Normalization特徴リサンプリングも使っていないですが、まあまあ生成できているようです。

3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3