GANとは
敵対的生成ネットワークGAN(Generative adversarial networks)とは生成器(Generator)と識別器(Discriminator)の2つのニューラルネットワークからなるシステムです。生成器は識別器を欺くように学習し、識別器は生成器の作った偽物と本物を識別するように学習を進めることで、結果として精度の高い偽物を生み出す事ができるようになります。
生成器と識別器が相反した目的のもとに学習する様が敵対的と呼ばれる所以のようです。
DCGAN
DCGAN(Deep Convolutional GAN)はGANで最も基本的なモデルになります。生成器は複数の転置畳み込み層からなり、ノイズを複数の転置畳み込み層で拡大することによりフェイク画像を生成します。識別機は複数の畳み込み層から形成され活性化関数にはLeakyReLUが使われます。
MNIST
MNISTは0~9の手書き数字の画像データセットで、主に画像認識を目的とした機械学習の初心者向けチュートリアルでよく使われているデータセットです。
実装
それではDCGANを実装していきます。GPUを使わないコードにしているので、少し時間はかかりますがPythonと必要なライブラリが入っていれば動かせると思います。
ライブラリのインポート
import matplotlib.pyplot as plt
import numpy as np
import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
定数の設定
n_epoch = 20
batch_size =64
lr = 0.001
z_dim = 100
Generatorクラス
class Generator(nn.Module):
def __init__(self, z_dim=100, ngf=128, nc=1):
super().__init__()
self.convt1 = self.conv_trans_layers(100, 512, 3, 1, 0)
self.convt2 = self.conv_trans_layers(512, 256, 3, 2, 0)
self.convt3 = self.conv_trans_layers(256, 128, 4, 2, 1)
self.convt4 = nn.Sequential(
nn.ConvTranspose2d(ngf, nc, 4, 2, 1),
nn.Tanh()
)
@staticmethod
def conv_trans_layers(in_channels, out_channels, kernel_size, stride, padding):
net = nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
return net
def forward(self, x):
out = self.convt1(x)
out = self.convt2(out)
out = self.convt3(out)
out = self.convt4(out)
return out
Discriminatorクラス
class Discriminator(nn.Module):
def __init__(self, nc=1, ndf=128):
super().__init__()
self.conv1 = self.conv_layers(nc, ndf, has_batch_norm=False)
self.conv2 = self.conv_layers(ndf, 2*ndf)
self.conv3 = self.conv_layers(2*ndf, 4*ndf, 3, 2, 0)
self.conv4 = nn.Sequential(
nn.Conv2d(4*ndf, 1, 3, 1, 0),
nn.Sigmoid()
)
@staticmethod
def conv_layers(in_channels, out_channels, kernel_size=4, stride=2, padding=1,
has_batch_norm=True):
layers = [
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
]
if has_batch_norm:
layers.append(nn.BatchNorm2d(out_channels))
layers.append(nn.LeakyReLU(0.2, inplace=True))
net = nn.Sequential(*layers)
return net
def forward(self, x):
out = self.conv1(x)
out = self.conv2(out)
out = self.conv3(out)
out = self.conv4(out)
return out
重みの初期化処理
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
データの前処理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, ), (0.5, ))
])
MNISTの読み込みとデータローダーの設定
dataset = dset.MNIST(root='./data/', download=True, train=True, transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)
GeneratorとDiscriminatorの生成
netG = Generator(z_dim=z_dim, ngf=128)
netG.apply(weights_init)
netD = Discriminator(nc=1, ndf=128)
netD.apply(weights_init)
損失関数と最適化
criterion = nn.BCELoss()
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=1e-5)
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=1e-5)
GeneratorとDiscriminatorの学習とフェイク画像の生成
for epoch in range(n_epoch):
for i, (real_imgs, labels) in enumerate(tqdm.tqdm(dataloader, position=0)):
batch_size = real_imgs.size()[0]
noise = torch.randn(batch_size, z_dim, 1, 1)
shape = (batch_size, 1, 1, 1)
labels_real = torch.ones(shape)
labels_fake = torch.zeros(shape)
netD.zero_grad()
output = netD(real_imgs)
lossD_real = criterion(output, labels_real)
fake_imgs = netG(noise)
output = netD(fake_imgs.detach())
lossD_fake = criterion(output, labels_fake)
lossD = lossD_real + lossD_fake
lossD.backward()
optimizerD.step()
netG.zero_grad()
output = netD(fake_imgs)
lossG = criterion(output, labels_real)
lossG.backward()
optimizerG.step()
grid_imgs = vutils.make_grid(fake_imgs[:24].detach())
grid_imgs_arr = grid_imgs.numpy()
plt.imshow(np.transpose(grid_imgs_arr, (1, 2, 0)))
plt.show()
vutils.save_image(fake_imgs, './result/{}.jpg'.format(epoch))
出力結果
(途中、省略)
初回の出力では怪しい数字も多いですが、20回目の出力は8~9割くらいが元のMNISTに近い数字が生成出来ていることがわかります。