5
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

CGAN(条件付きGAN)のシンプルな理解と実装

Posted at

CGAN(Conditional GAN)(条件付きGAN)を私なりになるべくシンプルに理解してなるべくシンプルに実装しました。

GAN

敵対的生成ネットワークというやつです。生成器と識別器の二つのニューラルネットワークを敵対させて学習させます。細かい話にはここでは触れません。また本記事と同じようなテンションで私が書いた記事があります。用語や実装の部分はここから引き継いでいる部分があるので、気になる人は読んでみてください。

GANのシンプルな理解と実装 - Qiita

CGAN

Conditional GANのことです。条件付きGANという意味になります。普通のGANでは、用意したデータセットに似たデータを乱数(ノイズ)から生成しますが、その生成するデータに制限を加えるといった感じですかね。例えばMNISTの手書き数字生成だったら、ランダムな数字ではなく、指定した数字を生成できるようになります。text-to-imageのモデルもCGANをベースに作ることが出来ます(textを条件とする感じ)。

GANからの変更点

では、そんなCGANを実現するアルゴリズムについて考えていくわけですが、実は普通のGANとほとんど変わりません。変更点は以下の三つのみです。

生成器・識別器が条件を考慮できるようにする

生成器・識別器が推論時に「条件」に何かしらの形で触れるようにするという意味です。例えば、条件を表したベクトルを適当な層にくっつけるとか。手書き数字であれば、対応するクラス(数字)のone-hotベクトルなんかが条件ベクトルになりますね。それを入力層とかに入力されるベクトルの末尾にくっつけるといった感じです。
media 0310.jpg
あと生成器の場合、最初に与えるノイズに情報を持たせるのもありですね。対応する箇所だけ大きくするみたいな。
media 0312.jpg

こんな感じで、生成器・識別器が「条件」を何らかの形で考慮できるようにする必要があります。

本物のデータセットに適切な条件を紐づける

普通のGANでは本物を用意するだけでしたが、CGANでは適切な条件を紐づけておく必要があります。理由はすぐ下で述べています。

条件が一致した本物のデータのみを本物とする

日本語が分かりづらいですが、気にしなーい。

これは識別器の学習時の話です。普通のGANでは生成器が生成したもの以外を全て本物としていましたが、CGANでは「正しい条件が紐づいているもの」という制約も加えます。

もう少し具体的に見ていきましょう。識別器の学習時は以下の三つのデータを用意します。

  • 正しい条件が紐づいている本物
  • 間違った条件が紐づいている本物
  • 生成器で生成した偽物

この三つの中で、一番上の「正しい条件が紐づいている本物」のみを本物、それ以外を偽物とし、適切な分類が行えるように識別器を学習させます。こうすることで、生成器は条件に沿ったデータが生成できるように学習するようになります。

実装

理論の説明が終わったので、実装していきましょう。今回は、MNISTを用いて指定した数字の手書き画像生成するモデルを作ります。
言語はPython、フレームワークはPyTorchを用います。

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms
from IPython.display import display

batch_size = 64
nz = 100
noize_std = 0.7
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

noise_stdは正規分布に従ってノイズを生成する際の標準偏差です。色々試した結果こうなりました。

MNIST

手書き数字のデータセットです。こんな画像。
mnist_ex.jfif
指定した数字でこんな感じの画像を生成したい訳です。

torchvisionから読み込みます。

dataset = MNIST(
    root="data",
    train=True,
    download=True,
    transform=transforms.ToTensor()
)

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True
)

sample_x, _ = next(iter(dataloader))
n_classes = len(torch.unique(dataset.targets)) # 10
w, h = sample_X.shape[-2:]                     # (28, 28)
image_size = w * h                             # 784

識別器

全結合層とReLUで作ります。条件はone-hotベクトルとして受け取り、入力するベクトルの末尾に結合します。

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(image_size + n_classes, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
        self._eye = torch.eye(n_classes, device=device) # 条件ベクトル生成用の単位行列

    def forward(self, x, labels):
        labels = self._eye[labels] # ラベル(条件)をone-hotベクトルに
        x = x.view(batch_size, -1) # 画像を1次元に
        x = torch.cat([x, labels], dim=1) # 画像と条件ベクトルを結合
        y = self.net(x)
        return y

生成器

全結合層とReLUとバッチ正規化で作ります。また、今回は条件の情報をノイズに持たせることにします。ということで、工夫を加える要素は特にありません。

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            self._linear(nz, 128),
            self._linear(128, 256),
            self._linear(256, 512),
            nn.Linear(512, image_size),
            nn.Sigmoid()
        )

    def _linear(self, input_size, output_size):
        return nn.Sequential(
            nn.Linear(input_size, output_size),
            nn.BatchNorm1d(output_size),
            nn.ReLU()
        )

    def forward(self, x):
        x = x.view(-1, nz)
        y = self.net(x)
        y = y.view(-1, 1, w, h) # 784 -> 1x28x28
        return y

ノイズ

生成器に与えるノイズに条件の情報を持たせます。どうやって持たせるかというと、対応する部分の値を大きくします。今回ノイズは100次元なので、条件(生成したい数字)が0の時は0~9次元、1の時は10~19次元、2の時は20~29次元...の部分が大きな値になるようにします。コードで書くとこんな感じで、

eye = torch.eye(n_classes, device=device)
def make_noise(labels):
    labels = eye[labels]
    labels = labels.repeat_interleave(nz // n_classes, dim=1)
    z = torch.normal(0, noise_std, size=(len(labels), nz), device=device)
    z = z + labels
    return z

実際にノイズを生成してグラフにするとこんな感じです。
output1.png
これだと少し分かり辛いですが、移動平均をとると分かり易くなります。
output2.png
対応する部分が大きくなっているのが分かるとおもいます。

これ(移動平均をとる前のベクトル)を生成器に与えて画像を生成します。

学習

ノーマルなGAN同様、識別器の学習 → 生成器の学習 を繰り返します。識別器の学習部分だけ少し変わる感じですね。

識別器の学習

fake_labels = torch.zeros(batch_size, 1).to(device) # 偽物のラベル
real_labels = torch.ones(batch_size, 1).to(device) # 本物のラベル
criterion = nn.BCELoss() # バイナリ交差エントロピー

最初に載せた記事とは0, 1が逆になっています。なお特に重要な意味はありません。一般的にそうなっていることが多いからというだけです。

具体的な学習部分はこんな感じです

z = make_noise(labels) # ノイズを生成
fake = netG(z) # 偽物を生成
false_labels = make_false_labels(labels) # 間違ったラベルを生成

pred_fake = netD(fake, labels) # 偽物を判定
pred_real_true = netD(X, labels) # 本物&正しいラベルを判定
pred_real_false = netD(X, false_labels) # 本物&間違ったラベルを判定

# 誤差を計算
loss_fake = criterion(pred_fake, fake_labels)
loss_real_true = criterion(pred_real_true, real_labels)
loss_real_false = criterion(pred_real_false, fake_labels)
lossD = loss_fake + loss_real_true + loss_real_false # 全ての和をとる

lossD.backward() # 逆伝播
optimD.step() # パラメータ更新

あとで定義しますが、各変数の説明は以下になります。

  • make_noise(): ノイズを生成する関数
  • make_false_labels(): 間違ったラベルを生成する関数
  • netG: 生成器
  • netD: 識別器
  • x: 本物のミニバッチ
  • labels: 正しいラベル
  • optimD: 識別器の最適化関数

生成器の学習

ノーマルなGANと同じです。

fake = netG(z) # 偽物を生成
pred = netD(fake, labels) # 偽物を判定
lossG = criterion(pred, real_labels) # 誤差を計算
lossG.backward() # 逆伝播
optimG.step() # パラメータ更新

誤差は交差エントロピーで求めるようにしました。

関数化

ここまでをまとめてみましょう。

eye = torch.eye(n_classes, device=device)
def make_noise(labels):
    labels = eye[labels]
    labels = labels.repeat_interleave(nz // n_classes, dim=1)
    z = torch.normal(0, noise_std, size=(len(labels), nz), device=device)
    z = z + labels
    return z

# 画像描画
def write(netG, n_rows=1, size=64):
    n_images = n_rows * n_classes
    z = make_noise(torch.tensor(list(range(n_classes)) * n_rows))
    images = netG(z)
    images = transforms.Resize(size)(images)
    img = torchvision.utils.make_grid(images, n_images // n_rows)
    img = transforms.functional.to_pil_image(img)
    display(img)

# 間違ったラベルの生成
def make_false_labels(labels):
    diff = torch.randint(1, n_classes, size=labels.size(), device=device)
    fake_labels = (labels + diff) % n_classes
    return fake_labels

fake_labels = torch.zeros(batch_size, 1).to(device) # 偽物のラベル
real_labels = torch.ones(batch_size, 1).to(device) # 本物のラベル
criterion = nn.BCELoss() # バイナリ交差エントロピー

def train(netD, netG, optimD, optimG, n_epochs, write_interval=1):
    # 学習モード
    netD.train()
    netG.train()

    for epoch in range(1, n_epochs+1):
        for X, labels in dataloader:
            X = X.to(device) # 本物の画像
            labels = labels.to(device) # 正しいラベル
            false_labels = make_false_labels(labels) # 間違ったラベル

            # 勾配をリセット
            optimD.zero_grad()
            optimG.zero_grad()

            # Discriminatorの学習
            z = make_noise(labels) # ノイズを生成
            fake = netG(z) # 偽物を生成
            pred_fake = netD(fake, labels) # 偽物を判定
            pred_real_true = netD(X, labels) # 本物&正しいラベルを判定
            pred_real_false = netD(X, false_labels) # 本物&間違ったラベルを判定
            # 誤差を計算
            loss_fake = criterion(pred_fake, fake_labels)
            loss_real_true = criterion(pred_real_true, real_labels)
            loss_real_false = criterion(pred_real_false, fake_labels)
            lossD = loss_fake + loss_real_true + loss_real_false # 全ての和をとる
            lossD.backward() # 逆伝播
            optimD.step() # パラメータ更新

            # Generatorの学習
            fake = netG(z) # 偽物を生成
            pred = netD(fake, labels) # 偽物を判定
            lossG = criterion(pred, real_labels) # 誤差を計算
            lossG.backward() # 逆伝播
            optimG.step() # パラメータ更新

        print(f'{epoch:>3}epoch | lossD: {lossD:.4f}, lossG: {lossG:.4f}')
        if write_interval and epoch % write_interval == 0:
            write(netG)

これで必要な関数が揃いました

学習

では実際に学習させてみましょう。

netD = Discriminator().to(device)
netG = Generator().to(device)
optimD = optim.Adam(netD.parameters(), lr=0.0002)
optimG = optim.Adam(netG.parameters(), lr=0.0002)
n_epochs = 5

print('初期状態')
write(netG)
train(netD, netG, optimD, optimG, n_epochs)

これを実行するとこうなります。
media 0299.jpg
各epochごとに表示している画像は、0~9の条件を左から順に与えて生成しています。この時点で、各条件(数字)の特徴はとらえられているように見えます。もう少し学習させてみましょう。+30epochs。

train(netD, netG, optimD, optimG, 30, 5)

経過はこうなりました。
media 0306.jpg

ちょっとずつ綺麗になっているように見えなくもないですが、ちょっと汚いですね。最終的に出来上がった生成器で画像を100枚生成してみます。
output3.png
2が全体的に汚いでしょうか。ただそれ以外は数字を識別できるものも多いので、まあ上手くいったのではないでしょうか。

CDCGAN

条件付きDCGANです。DCGANとはDeep Convolutional GANのことで、識別器と生成器にCNNを用いたGANです。こっちの方が画像がきれいになることが期待できます。実装してみましょう。といっても、生成器と識別器に畳み込み層を組み込むだけです。

識別器

畳み込み層と全結合層で作ります。条件の与え方は色々考えられそうですが、今回は、畳み込み層で得た特徴量にone-hotベクトルをくっつけて全結合層に渡すようにします。図にするとこんな感じです。
media 0316.jpg
コードは以下。

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            self._conv_layer(1, 16, 4, 2, 1),
            self._conv_layer(16, 32, 4, 2, 1),
            self._conv_layer(32, 64, 3, 2, 0),
            nn.Conv2d(64, 128, 3, 1, 0),
            nn.Flatten()
        )
        self.fc = nn.Sequential(
            nn.Linear(128 + n_classes, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )
        self._eye = torch.eye(n_classes, device=device) # 条件ベクトル生成用の単位行列

    def _conv_layer(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x, labels):
        x = self.conv(x) # 特徴抽出
        labels = self._eye[labels] # 条件(ラベル)をone-hotベクトルに
        x = torch.cat([x, labels], dim=1) # 画像と条件を結合
        y = self.fc(x)
        return y

畳み込み → バッチ正規化 → ReLU の流れを_conv_layer()関数にまとめています。forward()では、畳み込みで特徴量を得た後に条件ベクトルをくっつけて全結合層に投げているのが分かると思います。

生成器

転置畳み込みで作ります。また条件は先ほど同様ノイズに持たせるので特に工夫はいりません。

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            self._convT(nz, 128, 3, 1, 0),
            self._convT(128, 64, 3, 2, 0),
            self._convT(64, 32, 4, 2, 1),
            nn.ConvTranspose2d(32, 1, 4, 2, 1),
            nn.Sigmoid()
        )

    def _convT(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        x = x.view(-1, nz, 1, 1)
        y = self.net(x)
        return y

転置畳み込み → バッチ正規化 → ReLU の流れを_convT()関数にまとめています。またforward()では最初にノイズの形状を変えています。先ほど実装したmake_noise()関数では(バッチサイズ, ノイズの次元数)のノイズが生成されてしまい転置畳み込みに渡せないので、(バッチサイズ, ノイズの次元数, 1, 1)に整形しています。

学習

ではこれで学習させてみましょう。

netD = Discriminator().to(device)
netG = Generator().to(device)
optimD = optim.Adam(netD.parameters(), lr=0.0002)
optimG = optim.Adam(netG.parameters(), lr=0.0002)
n_epochs = 35
train(netD, netG, optimD, optimG, n_epochs)

最終的に出来上がった生成器で画像を生成するとこんな感じです。
output3.png

形が崩れているものもあるけど、線は綺麗になりましたね。

おまけ

生成器に与える条件をいじって遊ぶ。

  • 偏りを変化させる

今回は正規分布に従ってノイズを生成した後、生成したい数字に対応する箇所に1を足すことで条件の情報をノイズに持たせた。ではこの1を変化させるとどうなるか。
まずは大きくしてみる(1 -> 10)。生成するのは5。
output1.png
特徴はしっかり捉えられている。そして乱数の影響が少なくなるので、生成される画像はほぼ同じになる。

次に小さくしてみる(1 -> 0.5)
output2.png
条件の制約が弱まり、バラバラな画像になった。特徴はぎりぎり捉えられているのかな。

  • 二つの数字を条件に入れる

生成したい数字に対応する箇所が大きくなるようにしているが、複数の個所を大きくしたらどうなるのか。
3と6を大きくしてみよう。
output3.png
3っぽいのと、6っぽいのと、それらが混ざったような画像が生成された。

  • 何も与えない

ただの乱数をそのまま与えてみる。
output4.png
ごちゃごちゃ。乱数によって偶然生まれた偏りによってなにかしらの数字っぽくはなる様子。

  • 全部与える

全部の個所を大きくしてみる
output5.png
0, 6, 8, 0のどれかに近くなった。

おわり

終わりです。上記のプログラムはこちらのノートブックにまとめました。気になる人は見てみてください。

5
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?