27
33

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PytorchでGANを実装してみた。

Last updated at Posted at 2021-07-22

はじめに

2014年にGoodfellowらによって提案されたGenerative Adversarial Networks(GAN)は、コンピュータビジョンにおける画像生成の領域に革命をもたらした。見事で生き生きとした画像が、実際に機械によって生成されたものであると、誰も信じることができなかったからだ。この記事では、PyTorchによるGANの実装とその学習手順を紹介する。

目次

  1. 敵対生成ネットワーク
  2. 損失関数
  3. Pytorchによる実装
  4. 参考文献
  5. 著者のUdemy講座

#敵対生成ネットワーク

敵対生成ネットワーク(GAN)では、Generator(画像生成器)にランダム性のあるノイズを混入させて画像生成を行ったあとDiscriminator(偽物判別器)で、それが本物か偽物かを判定しながら学習を行う。Generatorは画像が偽物だと見破られないように学習し、Discriminatorは判別の精度をあげるように学習する。この相互作用が高精度に本物のような画像を生成することを可能にするのだ。

GAN.png

#損失関数

ジェネレーターはできるだけ本物に近い画像を生成したい一方で、判別器は生成された画像が偽物であることを識別したい。

D(x)をxが本物の画像である確率とし、G(z)がジェネレータの出力だとする。識別器は二値分類器に似ているので、識別器の目標は関数を最大化することになる一方で、ジェネレータは識別器をだますこと(つまり、識別器の関数を最小化することが目的となる。

このようなMin-Maxゲームは収束が難しいため、学習率などのハイパーパラメタの影響が大きい。


log(D(x)) + log(1-D(G(z)))

#実装

生成器と識別器の定義

datasetsで簡単に手に入るMNIST(0から9の数字60,000枚(28x28ピクセル))を扱うための生成器(Generator)と識別器(Discriminator)の実装をPytorchで行った例を示す。Pytorchを用いると比較的シンプルに定義することができる。

識別器はnn.Moduleを継承したクラスとして定義する。入力は28 * 28=784次元に平らにしたイメージの入力を想定し、隠れ層は512次元の全結合層とする。活性化関数にLeakyReLUを用いて、そのあとはシグモイド関数に入れ二値分類ができるようにしている。

生成器は、ランダムな128次元のノイズを入力し28 x 28ピクセルの画像を生成するように全結合層を3つ利用しており、活性化関数にはReLUを用いている。


import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import transforms

"""
ネットワーク・アーキテクチャー
判別器と生成器のアーキテクチャは以下の通りです。
"""

class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 1)
        self.activation = nn.LeakyReLU(0.1)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return nn.Sigmoid()(x)


class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.fc1 = nn.Linear(128, 1024)
        self.fc2 = nn.Linear(1024, 2048)
        self.fc3 = nn.Linear(2048, 784)
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        x = self.fc3(x)
        x = x.view(-1, 1, 28, 28)
        return nn.Tanh()(x)

訓練


"""
生成敵対ネットワークの作成に必要なライブラリのインポート
コードは主にPyTorchライブラリを使って開発されています
"""
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import transforms
import numpy as np
import matplotlib.pyplot as plt


# GPU利用可否確認
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



# ハイパーパラメタ設定
epochs = 30
lr = 2e-4
batch_size = 64
loss = nn.BCELoss()

# Model
G = generator().to(device)
D = discriminator().to(device)

G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))


"""
画像変換とデータローダの作成
ここでは分類ではなく生成のトレーニングを行っているので
train_loaderのみがロードされます。
"""
# Transform
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])
# Load data
train_set = datasets.MNIST('mnist/', train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)


"""
ネットワークの学習手順
識別器と生成器の損失はステップごとに更新される
判別器は本物と偽物を分類することを目的とする
ジェネレータは可能な限りリアルな画像を生成することを目的とする
"""
for epoch in range(epochs):
    for idx, (imgs, _) in enumerate(train_loader):
        idx += 1

        # 識別器の学習
        # 本物の入力は,MNISTデータセットの実際の画像
        # 偽の入力はジェネレータから
        # 本物の入力は1に、偽物は0に分類されるべきである
        real_inputs = imgs.to(device)
        real_outputs = D(real_inputs)
        real_label = torch.ones(real_inputs.shape[0], 1).to(device)

        noise = (torch.rand(real_inputs.shape[0], 128) - 0.5) / 0.5
        noise = noise.to(device)
        fake_inputs = G(noise)
        fake_outputs = D(fake_inputs)
        fake_label = torch.zeros(fake_inputs.shape[0], 1).to(device)

        outputs = torch.cat((real_outputs, fake_outputs), 0)
        targets = torch.cat((real_label, fake_label), 0)

        D_loss = loss(outputs, targets)
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()


        # Generatorのトレーニング
        # ジェネレータにとっての目標は 識別者に全てが1であると信じさせること
        noise = (torch.rand(real_inputs.shape[0], 128)-0.5)/0.5
        noise = noise.to(device)

        fake_inputs = G(noise)
        fake_outputs = D(fake_inputs)
        fake_targets = torch.ones([fake_inputs.shape[0], 1]).to(device)
        G_loss = loss(fake_outputs, fake_targets)
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        if idx % 100 == 0 or idx == len(train_loader):
            print('Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}'.format(epoch, idx, D_loss.item(), G_loss.item()))

    if (epoch+1) % 10 == 0:
        torch.save(G, 'Generator_epoch_{}.pth'.format(epoch))
        print('Model saved.')
  • 訓練結果

前略
Epoch 38 Iteration 300: discriminator_loss 0.705 generator_loss 0.710

訓練したGANで画像生成

訓練したGANにランダムなノイズを入力して画像を生成してみる。
人が書いたような8をGANで描くことができた。


for i in input:
  print("real")
  plt.imshow(i[0][0].reshape(28,28))
  plt.show()
  real_inputs = i[0][0]
  noise = (torch.rand(real_inputs.shape[0], 128)-0.5)/0.5
  noise = noise.to(device)
  fake_inputs = G(noise)
  print("fake")
  plt.imshow(fake_inputs[0][0].cpu().detach().numpy().reshape(28,28))
  plt.show()
  break
  • 結果

Real

image.png

Fake

image.png


意外にもシンプルな実装でGANを検証できた。明日は東京オリンピックだ。

参考文献

#講座

本記事の作者のUdemy講座を以下にて公開しています。Pytorchの実装を本格的に勉強したい方はハンズオンをご受講下さい。

27
33
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
27
33

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?