DCGANとは
GAN は提案された論文では Generator も Discriminator も全結合層のみで構成されていましたが、両者に畳み込み層を使用することでより自然な画像を生成できるようにしたものを Deep Convolutional GAN(DCGAN) と言います。
データセットの準備
CIFAR10のデータセットを使用します。
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
import torchvision.transforms as transforms
np.random.seed(0)
torch.manual_seed(0)
transform = transforms.ToTensor()
train_data = datasets.CIFAR10(root='data', train=True,
download=True, transform=transform)
batch_size = 128
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size)
モデルの定義
Discriminator
scriminator は画像を入力とし、その入力画像が偽物か本物であるかを識別する、単純な分類器です。今回においては、入力画像は 1 x 28 x 28 となり、Conv2d -> BatchNorm2d -> LeakyReLU の順に処理を通していきます。また、stride=2 とすることで Pooling を行わずとも畳み込みだけで画像サイズを半分にするように実装します。
class Discriminator(nn.Module):
def __init__(self, conv_dim=32):
super().__init__()
self.conv_dim = conv_dim
self.conv = nn.Sequential(
# input: 32 x 32
nn.Conv2d(in_channels=3, out_channels=conv_dim, kernel_size=4, stride=2, padding=1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# 16 x 16
nn.Conv2d(conv_dim, conv_dim*2, 4, 2, 1),
nn.BatchNorm2d(conv_dim * 2),
nn.LeakyReLU(0.2, inplace=True),
# 8 x 8
nn.Conv2d(conv_dim*2, conv_dim*4, 4, 2, 1),
nn.BatchNorm2d(conv_dim * 4),
nn.LeakyReLU(0.2, inplace=True),
# output: 4 x 4
)
self.fc = nn.Linear(conv_dim*4*4*4, 1)
def forward(self, x):
out = self.conv(x)
out = out.view(out.size(0), -1)
return self.fc(out)
Generator
Generator では、出力層の活性化関数に tanh を用います。
したがって、Generator の出力は -1 ~ 1 の値を取ることになるため、 Disciriminator の学習時には、本物画像の入力データも -1 ~ 1 のレンジをとるようにリスケーリングして上げる必要があります。
また、畳み込みの逆処理として転置畳み込みを行います。畳み込みでは kernel_size=4, stride=2 として画像サイズを縮小していましたので、転置畳み込みでも同様のハイパーパラメータで拡大を行います。転置畳み込みは PyTorch では ConvTranspose2d() で実装することができます。
class Generator(nn.Module):
def __init__(self, z_dim, conv_dim=32):
super().__init__()
self.conv_dim = conv_dim
self.fc = nn.Linear(z_dim, conv_dim*4*4*4)
self.deconv = nn.Sequential(
nn.ConvTranspose2d(conv_dim*4, conv_dim*2, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(conv_dim*2),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(conv_dim*2, conv_dim, 4, 2, 1, bias=False),
nn.BatchNorm2d(conv_dim),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(conv_dim, 3, 4, 2, 1, bias=False),
nn.Tanh()
)
def forward(self, x):
out = self.fc(x)
out = out.view(-1, self.conv_dim*4, 4, 4)
out = self.deconv(out)
return out
損失関数
Discriminator Loss
Discriminator Loss
Discriminator は本物画像と偽物画像、それぞれを入力として出力を行います。それぞれの画像に対して Loss を計算し、その合計を Discriminator Loss とします。偽物画像は、Generator が生成したものを指します。
実装上では、 d_loss = d_real_loss + d_fake_loss とします。
ここでの Loss の計算の際には、本物画像を入力として、real ならば 1, fake ならば 0 を正解ラベルとして Loss の算出を行います。正常な教師ラベルを使用するということです。
Generator Loss
Generator の場合は、偽物を本物として騙すことを目標とするため、Loss の計算の際には教師ラベルの値を反転させることがポイントです。
つまり、real ならば 0, fake ならば 1 という値を教師ラベルとして、損失関数の計算を行うことになります。
2 値交差エントロピー誤差は、nn.BCEWithLogitsLoss() として PyTorch で用意されています。
def real_loss(D_out, smooth=False):
batch_size = D_out.size(0)
if smooth:
labels = torch.ones(batch_size)*0.9
else:
labels = torch.ones(batch_size)
criterion = nn.BCEWithLogitsLoss()
loss = criterion(D_out.squeeze(), labels.to(device))
return loss
def fake_loss(D_out):
batch_size = D_out.size(0)
labels = torch.zeros(batch_size)
criterion = nn.BCEWithLogitsLoss()
loss = criterion(D_out.squeeze(), labels.to(device))
return loss
学習手順
GAN の学習は、Generator と Discriminator を交互に行います。
Discriminator
1.本物画像で discriminator loss を計算
2.偽物画像を generator で生成
3.偽物画像で discriminator loss を計算
4.2 つの discriminator loss を合算
5.合算した loss を元に逆伝播・パラメータ更新
Generator
1.偽物画像を generator で生成
2.偽物画像で disciminator loss を計算( label を反転させる )
3.逆伝播・パラメータ更新
Generator の学習の際には、loss の計算の際に label を反転させる点がポイントです。
また、Generator の出力は -1 ~ 1 の値を取ることになるため、 Disciriminator の学習時には、本物画像の入力データも -1 ~ 1 のレンジをとるようにリスケーリングして上げる必要があります。
正規化された画像は 0 ~ 1 の範囲をとりますので、2 倍して 1 を引けば -1 ~ 1 の範囲とすることができます。
num_epochs = 50
z_dim = 100
lr = 1e-4
conv_dim = 32
device = 'cuda' if torch.cuda.is_available() else 'cpu'
D = Discriminator(conv_dim).to(device)
G = Generator(z_dim=z_dim, conv_dim=conv_dim).to(device)
d_optim = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
g_optim = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
for epoch in range(num_epochs):
for idx, (real_imgs, _) in enumerate(train_loader):
batch_size = real_imgs.size(0)
# rescaling 0~1 -> -1~1
real_imgs = real_imgs * 2 - 1
real_imgs = real_imgs.to(device)
# =============================
# TRAIN THE DISCRIMINATOR
# =============================
d_optim.zero_grad()
# 本物画像で discriminator loss を計算
D_real = D(real_imgs)
d_real_loss = real_loss(D_real)
# 偽物画像を generator で生成
z = np.random.uniform(-1, 1, size=(batch_size, z_dim))
z = torch.from_numpy(z).float().to(device)
fake_imgs = G(z)
# 偽物画像で discriminator loss を計算
D_fake = D(fake_imgs)
d_fake_loss = fake_loss(D_fake)
# 合算した loss を元に逆伝播・パラメータ更新
d_loss = d_real_loss + d_fake_loss
d_loss.backward()
d_optim.step()
# =============================
# TRAIN THE GENERATOR
# =============================
g_optim.zero_grad()
# 偽物画像を generator で生成
z = np.random.uniform(-1, 1, size=(batch_size, z_dim))
z = torch.from_numpy(z).float().to(device)
fake_imgs = G(z)
# 偽物画像で disciminator loss を計算( label を反転させる )
D_fake = D(fake_imgs)
g_loss = real_loss(D_fake) # ここで fake_loss ではなく real_loss を用いる点がポイント
# 逆伝播・パラメータ更新
g_loss.backward()
g_optim.step()
# エポックごとにログを表示
print('Epoch [{:5d}/{:5d}] | d_loss: {:6.4f} | g_loss: {:6.4f}'.format(epoch+1, num_epochs, d_loss.item(), g_loss.item()))
Epoch [ 1/ 50] | d_loss: 0.2274 | g_loss: 2.7447
Epoch [ 2/ 50] | d_loss: 0.4706 | g_loss: 2.3281
Epoch [ 3/ 50] | d_loss: 0.4193 | g_loss: 3.3553
Epoch [ 4/ 50] | d_loss: 0.2060 | g_loss: 2.5614
Epoch [ 5/ 50] | d_loss: 0.2394 | g_loss: 3.4979
Epoch [ 6/ 50] | d_loss: 0.2210 | g_loss: 3.4103
Epoch [ 7/ 50] | d_loss: 0.2438 | g_loss: 2.8378
Epoch [ 8/ 50] | d_loss: 0.3433 | g_loss: 3.1535
Epoch [ 9/ 50] | d_loss: 0.2676 | g_loss: 3.3410
Epoch [ 10/ 50] | d_loss: 0.2905 | g_loss: 3.3490
Epoch [ 11/ 50] | d_loss: 0.2241 | g_loss: 3.4354
Epoch [ 12/ 50] | d_loss: 0.3221 | g_loss: 3.3013
Epoch [ 13/ 50] | d_loss: 0.2197 | g_loss: 2.7238
Epoch [ 14/ 50] | d_loss: 0.2134 | g_loss: 3.2118
Epoch [ 15/ 50] | d_loss: 0.1457 | g_loss: 2.6087
Epoch [ 16/ 50] | d_loss: 0.2674 | g_loss: 3.0657
Epoch [ 17/ 50] | d_loss: 0.1635 | g_loss: 2.4765
Epoch [ 18/ 50] | d_loss: 0.2818 | g_loss: 3.4029
Epoch [ 19/ 50] | d_loss: 0.2594 | g_loss: 2.6966
Epoch [ 20/ 50] | d_loss: 0.5297 | g_loss: 4.6321
Epoch [ 21/ 50] | d_loss: 0.3596 | g_loss: 3.6594
Epoch [ 22/ 50] | d_loss: 0.2560 | g_loss: 2.8364
Epoch [ 23/ 50] | d_loss: 0.2470 | g_loss: 3.4055
Epoch [ 24/ 50] | d_loss: 0.2985 | g_loss: 3.4396
Epoch [ 25/ 50] | d_loss: 0.3041 | g_loss: 3.6556
Epoch [ 26/ 50] | d_loss: 0.3278 | g_loss: 2.5795
Epoch [ 27/ 50] | d_loss: 0.2567 | g_loss: 2.9533
Epoch [ 28/ 50] | d_loss: 0.2527 | g_loss: 3.3458
Epoch [ 29/ 50] | d_loss: 0.2091 | g_loss: 2.9839
Epoch [ 30/ 50] | d_loss: 0.8663 | g_loss: 1.1759
Epoch [ 31/ 50] | d_loss: 0.1809 | g_loss: 3.2202
Epoch [ 32/ 50] | d_loss: 0.1543 | g_loss: 2.9448
Epoch [ 33/ 50] | d_loss: 0.1724 | g_loss: 3.0919
Epoch [ 34/ 50] | d_loss: 0.2185 | g_loss: 3.4930
Epoch [ 35/ 50] | d_loss: 0.1586 | g_loss: 2.9536
Epoch [ 36/ 50] | d_loss: 0.2682 | g_loss: 4.3134
Epoch [ 37/ 50] | d_loss: 0.2298 | g_loss: 3.5478
Epoch [ 38/ 50] | d_loss: 0.2628 | g_loss: 3.9757
Epoch [ 39/ 50] | d_loss: 0.1917 | g_loss: 3.4522
Epoch [ 40/ 50] | d_loss: 0.1837 | g_loss: 3.5620
Epoch [ 41/ 50] | d_loss: 0.1360 | g_loss: 3.4125
Epoch [ 42/ 50] | d_loss: 0.2444 | g_loss: 3.7329
Epoch [ 43/ 50] | d_loss: 0.2900 | g_loss: 4.5762
Epoch [ 44/ 50] | d_loss: 0.8544 | g_loss: 2.3379
Epoch [ 45/ 50] | d_loss: 0.1917 | g_loss: 3.6750
Epoch [ 46/ 50] | d_loss: 0.1782 | g_loss: 3.6485
Epoch [ 47/ 50] | d_loss: 0.9766 | g_loss: 2.0075
Epoch [ 48/ 50] | d_loss: 0.1446 | g_loss: 2.8509
Epoch [ 49/ 50] | d_loss: 0.2435 | g_loss: 4.2065
Epoch [ 50/ 50] | d_loss: 0.2663 | g_loss: 4.3056
画像を生成
学習した Generator を用いて、ランダムに生成したノイズベクトルから画像を生成します。
# 100 次元のノイズを 25 個作成
z_dim = 100
z = np.random.uniform(-1, 1, size=(25, z_dim))
z = torch.from_numpy(z).float().to(device)
G.eval().to(device)
# Generator に通して偽物画像を生成する
fake_images = G(z)
fake_images.size()
torch.Size([25, 3, 32, 32])
生成した 25 枚の画像を表示します。
plt.figure(figsize=(12, 12))
for i in range(25):
img = fake_images[i].cpu().detach().numpy()
img = np.transpose(img, (1, 2, 0))
img = ((img + 1) * 255 / 2).astype(np.uint8) # rescaling
plt.subplot(5, 5, i+1)
plt.axis('off')
plt.imshow(img)