LoginSignup
1
2

【PyTorch】GAN(Generative Adversarial Network) 備忘録

Last updated at Posted at 2023-06-28

GAN(GAN, Ian J. Goodfellow, 2014)とは、敵対的生成ネットワークといわれる生成モデルの一つで、教師なし学習の一つである。

生成器(Generator)で、特徴の種に相当する一次元ランダムノイズと正解画像一次元データを入力として、ニューラルネットワークにより特定の画像データを出力し、識別器(Discriminator)で偽物と本物のデータをそれぞれ0,1のテンソルて仮定して、損失関数を計算し最適化する。生成器は、ランダムノイズを識別器に入力し、その出力と正解1のテンソルで損失関数を計算し最適化する。

この2つのネットワークの学習を交互に行い、お互いの損失関数を最適化することで、生成器に本物のデータに近い偽物のデータを生成できるように、識別器に生成器の出力を正しく判別するように学習する。生成器は画像生成や疑似データ生成に使用され、識別器は異常検知などに使用される。

実装

下記、サイトを参考にGANをPyTorchで実装してみた。

使用しているデータはFashion-MNISTで、28x28のグレースケール画像で、10のラベルデータが、60,000(トレーニング)枚+ 10,000(テスト)枚入っている。

Fashion-MNIST1.jpg

Pythonコード全文

PyTorchなど必要なライブラリをインストールできていれば、コピペで実行可能

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optimizers
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# GANモデルのクラス
class GAN(nn.Module):
    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
        self.G = Generator(device=device)
        self.D = Discriminator(device=device)
    def forward(self, x):
        x = self.G(x)
        y = self.D(x)
        return y
    
# 識別器(Generator)のクラス
class Discriminator(nn.Module):
    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
        self.conv1 = nn.Conv2d(1, 128,
                               kernel_size=(3, 3),
                               stride=(2, 2),
                               padding=1)
        self.relu1 = nn.LeakyReLU(0.2)
        self.conv2 = nn.Conv2d(128, 256,
                               kernel_size=(3, 3),
                               stride=(2, 2),
                               padding=1)
        self.bn2 = nn.BatchNorm2d(256)
        self.relu2 = nn.LeakyReLU(0.2)
        self.fc = nn.Linear(256*7*7, 1024)
        self.bn3 = nn.BatchNorm1d(1024)
        self.relu3 = nn.LeakyReLU(0.2)
        self.out = nn.Linear(1024, 1)

    def forward(self, x):
        h = self.conv1(x)
        h = self.relu1(h)
        h = self.conv2(h)
        h = self.bn2(h)
        h = self.relu2(h)
        h = h.view(-1, 256*7*7)
        h = self.fc(h)
        h = self.bn3(h)
        h = self.relu3(h)
        h = self.out(h)
        y = torch.sigmoid(h)
        return y

# 生成器(Discriminator)のクラス
class Generator(nn.Module):
    def __init__(self, input_dim=100, device='cpu'):
        super().__init__()
        self.device = device
        self.linear = nn.Linear(input_dim, 256*14*14)
        self.bn1 = nn.BatchNorm1d(256*14*14)
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv2d(256, 128,
                               kernel_size=(3, 3),
                               padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv2d(128, 64,
                               kernel_size=(3, 3),
                               padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.relu3 = nn.ReLU()
        self.conv3 = nn.Conv2d(64, 1, kernel_size=(1, 1))
    
    def forward(self, x):
        h = self.linear(x)
        h = self.bn1(h)
        h = self.relu1(h)
        h = h.view(-1, 256, 14, 14)
        h = F.interpolate(h, size=(28, 28))
        h = self.conv1(h)
        h = self.bn2(h)
        h = self.relu2(h)
        h = self.conv2(h)
        h = self.bn3(h)
        h = self.relu3(h)
        h = self.conv3(h)
        y = torch.sigmoid(h)
        return y
    
#一様乱数のノイズを生成する関数
def gen_noise(batch_size):
    return torch.empty(batch_size, 100).uniform_(0, 1).to(device)
    
if __name__ == '__main__':
    # ランダムシードの固定
    np.random.seed(1234)
    torch.manual_seed(1234)
    # 演算デバイスの設定、GPUが使用可能ならGPUを使用、そうでなければCPUを使用
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    #### 1. データの読み込み
    # Fashion-MNIST: MNISTのファッションデータ
    # 60,000 例のトレーニング セットと 10,000 例のテスト セット
    # 28x28 グレースケール イメージで、10 クラスのラベル
    root = os.path.join(os.path.dirname(__file__), '.', 'data', 'fashion_mnist')
    # 複数の前処理を定義, 画像データをテンソルに変換する, 1次元のベクトルに変換するラムダ関数
    transform = transforms.Compose([transforms.ToTensor(),lambda x: x.view(-1)])
    # Fashion-MNISTのデータを前処理して、ダウンロードする
    mnist_train = torchvision.datasets.FashionMNIST(root=root,
                                                    download=True,
                                                    train=True,
                                                    transform=transform)
    # 訓練データをバッチで取得するためのデータローダ
    # 1つのバッチに含まれるサンプル数が100
    train_dataloader = DataLoader(mnist_train,batch_size=100,shuffle=True)

    #### 2. モデルの構築
    # モデルの定義
    model = GAN(device=device).to(device)

    # 最適化手法の定義
    # Adam(Adaptive moment): 適応的モーメント, モーメンタムとRMSPropを組み合わせて、鞍点に落ち込みにくく、学習率を0にしにくくしている
    # Momentum: 運動量, 勾配に落ち込む際に加速度項を追加して、物理的な加速度を再現して、鞍点に落ち込むのを防ぐ作用
    # AdaGrad(Adaptive Gradient): 適応的勾配, 重みwの学習率の自動調整, k方向の累積の修正量が多い場合はk方向の修正量を減らす作用
    # RMS Prop: Root Mean Square Prop, 重みの移動平均, 過去の情報が指数関数的に薄まって作用
    optimizer_D = optimizers.Adam(model.D.parameters(), lr=0.0002) # learning_rate: 学習率
    optimizer_G = optimizers.Adam(model.G.parameters(), lr=0.0002)

    # 損失関数の定義
    # 交差エントロピー誤差 Binary Cross Entropy Loss {H(p, q) = -\sum_{x} p(x) log(q(x))} 微分計算がしやすい
    criterion = nn.BCELoss()
    def compute_loss(label, preds):
        return criterion(preds, label)

    #### 3. モデルの訓練   
    def train_step(x):
        # 入力データのバッチサイズ
        batch_size = x.size(0)
        # 識別器と生成器の訓練モードを有効にする
        model.D.train()
        model.G.train()
        # 識別器の訓練
        # 784 次元のベクトルを 28x28 の画像に変更
        # x.size() = 100 × 784の行列
        x = x.view(-1, 1, 28, 28)
        # x.size() = 28 × 28の行列, チャンネル数1(グレースケール)
        
        # 入力データに対する推論をして、次元を削減して1次元のテンソルに変換
        preds = model.D(x).squeeze() # 本物画像に対する予測
        # 本物の画像に対する正解ラベルとして、サイズが batch_size の値1のテンソル
        t = torch.ones(batch_size).float().to(device)
        # 予測値 preds と本物正解ラベル t を用いて交差エントロピー誤差を計算
        loss_D_real = compute_loss(t, preds)

        # 偽物画像生成
        noise = gen_noise(batch_size)
        # 生成器モデルにノイズ noise を入力し、生成された偽物の画像
        gen = model.G(noise)
        # 生成器モデルにノイズ noise を入力し、生成された偽物の画像に対する推論, 生成器Gに勾配が伝わらないようにするため,detachしている。
        preds = model.D(gen.detach()).squeeze()
        # 偽物の画像に対する正解ラベルとして、サイズが batch_size の0のテンソル
        t = torch.zeros(batch_size).float().to(device)
        # 予測値 preds と偽物正解ラベル t を用いて交差エントロピー誤差を計算
        loss_D_fake = compute_loss(t, preds)
        
        # 識別器と生成器の損失関数の合算
        loss_D = loss_D_real + loss_D_fake
        
        # モデル内のパラメータの勾配を初期化して、識別機の学習
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        # 生成器の学習
        noise = gen_noise(batch_size)
        gen = model.G(noise)
        preds = model.D(gen).squeeze()
        # 生成器の出力はすべて正解として損失関数を計算
        t = torch.ones(batch_size).float().to(device)
        loss_G = compute_loss(t, preds)
        # モデル内のパラメータの勾配を初期化して、生成器の学習
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        return loss_D, loss_G

    epochs = 10
    for epoch in range(epochs):
        train_loss_D = 0.
        train_loss_G = 0.
        test_loss = 0.
        for (x, _) in train_dataloader:
            x = x.to(device)
            loss_D, loss_G = train_step(x)
            train_loss_D += loss_D.item()
            train_loss_G += loss_G.item()
        train_loss_D /= len(train_dataloader)
        train_loss_G /= len(train_dataloader)
        print('Epoch: {}, D Cost: {:.3f}, G Cost: {:.3f}'.format(epoch+1,train_loss_D,train_loss_G))

    # モデルのパラメータのみを保存
    torch.save(model.state_dict(), 'model_weight.pth')
    # モデル全体を保存
    torch.save(model, 'model.pth')

    ### 4. Test model
    # 保存したモデルの読み込み
    model = GAN(device=device).to(device)
    model.load_state_dict(torch.load('model_weight.pth'))
    # 特定のバッチサイズで一様乱数の画像を生成して、生成器で推論する関数
    def generate(batch_size=16):
        model.eval()
        noise = gen_noise(batch_size)
        gen = model.G(noise)
        return gen
    images = generate(batch_size=16)
    # 次元を削減して1次元のテンソルに変換して、GPUからCPUにデータを移動する
    images = images.squeeze().detach().cpu().numpy()
    # 生成した画像を一つのキャンバスに並べて表示する
    plt.figure(figsize=(6, 6))
    for i, image in enumerate(images):
        plt.subplot(4, 4, i+1)
        plt.imshow(image, cmap='binary_r')
        plt.axis('off')
        plt.tight_layout()
    plt.savefig("GAN-Fashion-MNIST.jpg")
    plt.show()

実行結果

エポック数を変更して、生成画像がどのように変わるのか確認してみた。

epochsを2で生成した画像
GAN-Fashion-MNIST_epochs2.jpg

epochsを5で生成した画像
GAN-Fashion-MNIST_epochs5.jpg

epochsを10で生成した画像
GAN-Fashion-MNIST_epochs10.jpg

epochsを20で生成した画像
GAN-Fashion-MNIST.jpg

epochsを40で生成した画像
GAN-Fashion-MNIST.jpg

PyTorchチュートリアルの実装

下記チュートリアルを参考に実装した。

データ準備

画像データは、102の花のカテゴリーで各クラス 40〜258枚の画像で構成されたOxford 102を使用した。
下記リンクからダウンロードできる。

Pythonコード全文

import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML


# Generator Code
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. ``(ngf*8) x 4 x 4``
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. ``(ngf*4) x 8 x 8``
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. ``(ngf*2) x 16 x 16``
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. ``(ngf) x 32 x 32``
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. ``(nc) x 64 x 64``
        )

    def forward(self, input):
        return self.main(input)
    
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is ``(nc) x 64 x 64``
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf) x 32 x 32``
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*2) x 16 x 16``
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*4) x 8 x 8``
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*8) x 4 x 4``
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)


# 生成器と識別器のモデルパラメーターの初期化
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


if __name__ == '__main__':
    manualSeed = 999
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    torch.use_deterministic_algorithms(True)

    #### 0. ハイパーパラメータ設定
    # Number of workers for dataloader
    workers = 2
    # バッチサイズ
    batch_size = 128
    # カラーチャネル
    nc = 3
    # Size of z latent vector (i.e. size of generator input)
    nz = 100
    # 生成器と識別器のの特徴マップのサイズ
    ngf, ndf = 64, 64
    # エポック数, 学習率
    num_epochs, lr = 10, 0.0002
    # Number of GPUs available. Use 0 for CPU mode.
    ngpu = 1
    # Establish convention for real and fake labels during training
    real_label, fake_label = 1., 0.

    # 1. データ読み込み
    dataroot = "oxford-102/"
    image_size = 64
    dataset = dset.ImageFolder(root=dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(image_size),
                                   transforms.CenterCrop(image_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             shuffle=True, num_workers=workers)
    
    #### 2. モデルの構築  
    device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
    # 生成器モデル宣言
    netG = Generator(ngpu).to(device)
    if (device.type == 'cuda') and (ngpu > 1):
        netG = nn.DataParallel(netG, list(range(ngpu)))
    netG.apply(weights_init)
    # 識別器モデル宣言
    netD = Discriminator(ngpu).to(device)
    if (device.type == 'cuda') and (ngpu > 1):
        netD = nn.DataParallel(netD, list(range(ngpu)))
    netD.apply(weights_init)
    # 損失関数の定義
    criterion = nn.BCELoss()
    # 最適化手法の定義
    beta1 = 0.5
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

    #### 3. モデルの訓練
    print("Starting Training Loop")
    G_losses, D_losses = [], []
    iters = 0
    for epoch in range(num_epochs):
        for i, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ## Train with all-real batch
            netD.zero_grad()
            # Format batch
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            # Forward pass real batch through D
            output = netD(real_cpu).view(-1)
            # Calculate loss on all-real batch
            errD_real = criterion(output, label)
            # Calculate gradients for D in backward pass
            errD_real.backward()
            D_x = output.mean().item()

            ## Train with all-fake batch
            # Generate batch of latent vectors
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            # Generate fake image batch with G
            fake = netG(noise)
            label.fill_(fake_label)
            # Classify all fake batch with D
            output = netD(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, label)
            # Calculate the gradients for this batch, accumulated (summed) with previous gradients
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            # Compute error of D as sum over the fake and the real batches
            errD = errD_real + errD_fake
            # Update D
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = netD(fake).view(-1)
            # Calculate G's loss based on this output
            errG = criterion(output, label)
            # Calculate gradients for G
            errG.backward()
            D_G_z2 = output.mean().item()
            # Update G
            optimizerG.step()

            # Output training stats
            if i % 50 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (epoch, num_epochs, i, len(dataloader), errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            # Save Losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())
            iters += 1

    #### 4. モデルの保存
    torch.save(netG.state_dict(), 'model_weight.pth')
    torch.save(netG, 'model.pth')

    #### 5. Test model
    # 保存したモデルの読み込み
    netG = Generator(ngpu).to(device)
    netG.load_state_dict(torch.load('model_weight.pth'))

    fixed_noise = torch.randn(64, nz, 1, 1, device=device)
    with torch.no_grad():
        fake = netG(fixed_noise).detach().cpu()
    #複数の画像から、それらをグリッド上に並べた画像を作成できる
    image = vutils.make_grid(fake, padding=2, normalize=True)

    plt.figure(figsize=(15,15))
    plt.axis("off")
    plt.title("Fake Images")
    plt.tight_layout()
    plt.imshow(np.transpose(image,(1,2,0)))
    plt.savefig("GAN-Flower.jpg")
    plt.show()

実行結果

うっすら花っぽい画像が生成された。
GAN-Flower.jpg

入力画像サイズ変更

64pxを128pxに変更する場合、モデルに一層追加して、画像サイズに合わせる必要がある。

128pxバージョン

import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML


# Generator Code
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 16, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 16),
            nn.ReLU(True),
            # !!!ADDIND!!!
            nn.ConvTranspose2d(ngf * 16, ngf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. ``(ngf*8) x 4 x 4``
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. ``(ngf*4) x 8 x 8``
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. ``(ngf*2) x 16 x 16``
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. ``(ngf) x 32 x 32``
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. ``(nc) x 64 x 64``
        )

    def forward(self, input):
        return self.main(input)
    
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is ``(nc) x 64 x 64``
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf) x 32 x 32``
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*2) x 16 x 16``
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*4) x 8 x 8``
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # !!!ADDING!!!
            nn.Conv2d(ndf * 8, ndf * 16, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 16),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*8) x 2 x 2``
            nn.Conv2d(ndf * 16, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)


# 生成器と識別器のモデルパラメーターの初期化
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


if __name__ == '__main__':
    manualSeed = 999
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    torch.use_deterministic_algorithms(True)

    #### 0. ハイパーパラメータ設定
    # Number of workers for dataloader
    workers = 2
    # バッチサイズ
    batch_size = 128
    # カラーチャネル
    nc = 3
    # Size of z latent vector (i.e. size of generator input)
    nz = 100
    # 生成器と識別器のの特徴マップのサイズ
    ngf, ndf = 64, 64
    # エポック数, 学習率
    num_epochs, lr = 10, 0.0002
    # Number of GPUs available. Use 0 for CPU mode.
    ngpu = 1
    # Establish convention for real and fake labels during training
    real_label, fake_label = 1., 0.

    # 1. データ読み込み
    dataroot = "oxford-102/"
    image_size = 128
    dataset = dset.ImageFolder(root=dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(image_size),
                                   transforms.CenterCrop(image_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             shuffle=True, num_workers=workers)
    
    #### 2. モデルの構築  
    device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
    # 生成器モデル宣言
    netG = Generator(ngpu).to(device)
    if (device.type == 'cuda') and (ngpu > 1):
        netG = nn.DataParallel(netG, list(range(ngpu)))
    netG.apply(weights_init)
    # 識別器モデル宣言
    netD = Discriminator(ngpu).to(device)
    if (device.type == 'cuda') and (ngpu > 1):
        netD = nn.DataParallel(netD, list(range(ngpu)))
    netD.apply(weights_init)
    # 損失関数の定義
    criterion = nn.BCELoss()
    # 最適化手法の定義
    beta1 = 0.5
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

    #### 3. モデルの訓練
    print("Starting Training Loop")
    G_losses, D_losses = [], []
    iters = 0
    for epoch in range(num_epochs):
        for i, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ## Train with all-real batch
            netD.zero_grad()
            # Format batch
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            # Forward pass real batch through D
            output = netD(real_cpu).view(-1)
            # Calculate loss on all-real batch
            errD_real = criterion(output, label)
            # Calculate gradients for D in backward pass
            errD_real.backward()
            D_x = output.mean().item()

            ## Train with all-fake batch
            # Generate batch of latent vectors
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            # Generate fake image batch with G
            fake = netG(noise)
            label.fill_(fake_label)
            # Classify all fake batch with D
            output = netD(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, label)
            # Calculate the gradients for this batch, accumulated (summed) with previous gradients
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            # Compute error of D as sum over the fake and the real batches
            errD = errD_real + errD_fake
            # Update D
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = netD(fake).view(-1)
            # Calculate G's loss based on this output
            errG = criterion(output, label)
            # Calculate gradients for G
            errG.backward()
            D_G_z2 = output.mean().item()
            # Update G
            optimizerG.step()

            # Output training stats
            if i % 50 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (epoch, num_epochs, i, len(dataloader), errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            # Save Losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())
            iters += 1

    #### 4. モデルの保存
    torch.save(netG.state_dict(), 'model_weight.pth')
    torch.save(netG, 'model.pth')

    #### 5. Test model
    # 保存したモデルの読み込み
    netG = Generator(ngpu).to(device)
    netG.load_state_dict(torch.load('model_weight.pth'))

    fixed_noise = torch.randn(64, nz, 1, 1, device=device)
    with torch.no_grad():
        fake = netG(fixed_noise).detach().cpu()
    #複数の画像から、それらをグリッド上に並べた画像を作成できる
    image = vutils.make_grid(fake, padding=2, normalize=True)

    plt.figure(figsize=(15,15))
    plt.axis("off")
    plt.title("Fake Images")
    plt.tight_layout()
    plt.imshow(np.transpose(image,(1,2,0)))
    plt.savefig("GAN-Flower.jpg")
    plt.show()

実行結果

先ほどよりはっきりしているが、拡散モデルには程遠い
epoch数 30
GAN-Flower.jpg

epoch数 100
GAN-Flower_epoch_100.jpg

まとめ

今回は、GAN(Generative Adversarial Network)のPyTorchによる実装方法を紹介した。

参考文献

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2