4
1

More than 1 year has passed since last update.

pytorchでDCGANを作成しMNISTのフェイク画像を生成する

Posted at

GANとは

 敵対的生成ネットワークGAN(Generative adversarial networks)とは生成器(Generator)と識別器(Discriminator)の2つのニューラルネットワークからなるシステムです。生成器は識別器を欺くように学習し、識別器は生成器の作った偽物と本物を識別するように学習を進めることで、結果として精度の高い偽物を生み出す事ができるようになります。
 生成器と識別器が相反した目的のもとに学習する様が敵対的と呼ばれる所以のようです。

qiita_20211017_1.png

DCGAN

 DCGAN(Deep Convolutional GAN)はGANで最も基本的なモデルになります。生成器は複数の転置畳み込み層からなり、ノイズを複数の転置畳み込み層で拡大することによりフェイク画像を生成します。識別機は複数の畳み込み層から形成され活性化関数にはLeakyReLUが使われます。

MNIST

 MNISTは0~9の手書き数字の画像データセットで、主に画像認識を目的とした機械学習の初心者向けチュートリアルでよく使われているデータセットです。

mnist.png

実装

 それではDCGANを実装していきます。GPUを使わないコードにしているので、少し時間はかかりますがPythonと必要なライブラリが入っていれば動かせると思います。

ライブラリのインポート

import matplotlib.pyplot as plt
import numpy as np
import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

定数の設定

n_epoch = 20
batch_size =64
lr = 0.001
z_dim = 100

Generatorクラス

class Generator(nn.Module):
    def __init__(self, z_dim=100, ngf=128, nc=1):
        super().__init__()
        self.convt1 = self.conv_trans_layers(100, 512, 3, 1, 0)
        self.convt2 = self.conv_trans_layers(512, 256, 3, 2, 0)
        self.convt3 = self.conv_trans_layers(256, 128, 4, 2, 1)
        self.convt4 = nn.Sequential(
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1), 
            nn.Tanh()
        )

    @staticmethod
    def conv_trans_layers(in_channels, out_channels, kernel_size, stride, padding):
        net = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False), 
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)    
        )
        return net

    def forward(self, x):
        out = self.convt1(x)
        out = self.convt2(out)
        out = self.convt3(out)
        out = self.convt4(out)
        return out

Discriminatorクラス

class Discriminator(nn.Module):
    def __init__(self, nc=1, ndf=128):
        super().__init__()
        self.conv1 = self.conv_layers(nc, ndf, has_batch_norm=False)
        self.conv2 = self.conv_layers(ndf, 2*ndf)
        self.conv3 = self.conv_layers(2*ndf, 4*ndf, 3, 2, 0)
        self.conv4 = nn.Sequential(
            nn.Conv2d(4*ndf, 1, 3, 1, 0), 
            nn.Sigmoid()           
        )

    @staticmethod
    def conv_layers(in_channels, out_channels, kernel_size=4, stride=2, padding=1, 
                    has_batch_norm=True):
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
            ]
        if has_batch_norm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        net = nn.Sequential(*layers)
        return net

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        return out

重みの初期化処理

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)

データの前処理

transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Normalize((0.5, ), (0.5, ))
    ])

MNISTの読み込みとデータローダーの設定

dataset = dset.MNIST(root='./data/', download=True, train=True, transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

GeneratorとDiscriminatorの生成

netG = Generator(z_dim=z_dim, ngf=128)
netG.apply(weights_init)

netD = Discriminator(nc=1, ndf=128)
netD.apply(weights_init)

損失関数と最適化

criterion = nn.BCELoss()

optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=1e-5)
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=1e-5)

GeneratorとDiscriminatorの学習とフェイク画像の生成

for epoch in range(n_epoch):
    for i, (real_imgs, labels) in enumerate(tqdm.tqdm(dataloader, position=0)):
        batch_size = real_imgs.size()[0]
        noise = torch.randn(batch_size, z_dim, 1, 1)

        shape = (batch_size, 1, 1, 1)
        labels_real = torch.ones(shape)
        labels_fake = torch.zeros(shape)

        netD.zero_grad()
        output = netD(real_imgs)
        lossD_real = criterion(output, labels_real)

        fake_imgs = netG(noise)
        output = netD(fake_imgs.detach())
        lossD_fake = criterion(output, labels_fake)

        lossD = lossD_real + lossD_fake
        lossD.backward()
        optimizerD.step()

        netG.zero_grad()
        output = netD(fake_imgs)
        lossG = criterion(output, labels_real)
        lossG.backward()
        optimizerG.step()

    grid_imgs = vutils.make_grid(fake_imgs[:24].detach())
    grid_imgs_arr = grid_imgs.numpy()
    plt.imshow(np.transpose(grid_imgs_arr, (1, 2, 0)))
    plt.show()

    vutils.save_image(fake_imgs, './result/{}.jpg'.format(epoch))

出力結果

epoch:0
0.jpg

(途中、省略)

epoch:19
19.jpg

 初回の出力では怪しい数字も多いですが、20回目の出力は8~9割くらいが元のMNISTに近い数字が生成出来ていることがわかります。

4
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
1