3
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

【PyTorchチュートリアル⑪】DCGAN Tutorial

Posted at

はじめに

前回に引き続き、PyTorch 公式チュートリアル の第11弾です。
今回は DCGAN Tutorial を進めます。

DCGAN Tutorial

Introduction

このチュートリアルでは、DCGANを紹介します。実在の有名人の画像をもとに新しい有名人の画像を生成する敵対的生成ネットワーク(GAN)をトレーニングします。
このチュートリアルのコードは、pytorch/examples にある dcgan の実装を利用しています。GANの事前知識は必要ありません。処理時間の都合上、GPUを利用したほうがよいでしょう。
それでは始めてみましょう。

Generative Adversarial Networks

What is a GAN?

GANは、ニューラルネットワークのモデルにトレーニングデータの分布を学習させるためのフレームワークで、学習した分布から新しいデータを生成できます。
GANは、2014年にイアン・グッドフェロー氏によって考え出され、論文 Generative Adversarial Nets で最初に発表されました。
GANは、**生成器(generator)識別器(discriminator)**の2つのモデルで構成されています。
生成器は、トレーニング画像を元に「偽の」画像を生成します。
一方、識別器は、画像が実際のトレーニング画像であるか、生成器からの偽の画像であるかを判別します。
生成器は識別器が判断できないような良い偽物を生成することで学習し、識別器は実際の画像と偽の画像を正しく分類するように学習します。
この学習のゴールは、生成器がトレーニングデータのように見える完全な偽物を生成しているときであり、識別器は、生成器の出力が本物か偽物かを常に50%の信頼度で推測します。

それでは、識別器から見ていきましょう。
$x$ を画像を表すデータとします。
$D(x)$ は、$x$が本物の画像である確率(スカラー値)を出力する識別器です。ここでは、画像を扱っているため、$D(x)$ への入力はCHW(色、縦、横)で、このチュートリアルでは 3x64x64 の画像です。
$D(x)$ のスカラー値は $x$ が本物の画像の場合は大きい値になり、生成器が作製した画像の場合は小さい値になります。$D(x)$ は通常の 2クラス分類器と考えることもできます。

次に生成器を見ていきます。$z$ を標準正規分布からサンプリングされた潜在空間ベクトルとします。 $G(z)$ は、潜在ベクトル $z$ をデータ空間にマッピングするジェネレーター関数を表します。 $G$ の目標は、トレーニングデータの取得元の分布 ($p_{data}$) を推定して、推定された分布 ($p_g$) から偽のサンプルを生成できるようにすることです。

(ここでいう潜在ベクトルは、n次元の標準化されたランダム値のことです。このチュートリアルでは 100次元としています。簡単のため2次元で考えると、点(1,1)を入力すると1つの画像が作成され、点(3,2)を入力すると別の画像が作成されます。以下の図のように入力の点を徐々に変化させると、生成される画像も徐々に変化するイメージです。※実際には標準化されているため(1,1)や(3,2)の点はありません)

ダウンロード.png

$D(G(z))$ は、ジェネレータ $G$ の出力が実像である確率(スカラー)です。 グッドフェロー氏の論文で説明されているように、$D$ と $G$ は、$D$ が実数と偽物を正しく分類する確率($logD(x)$)を最大化しようとし、$G$ が出力が偽物であると予測する確率( $log(1-D(G(x)))$ )を最小化しようとするミニマックスゲームをプレイします。論文から、GAN損失関数は以下で定義されます。

$\underset{G}{\text{min}} \underset{D}{\text{max}}V(D,G) = \mathbb{E}_{x\sim p_{data}(x)}\big[logD(x)\big] + \mathbb{E}_{z\sim p_{z}(z)}\big[log(1-D(G(z)))\big] $

理論的には、このミニマックスゲームの解は、$p_g = p_{data}$ であり、識別器は入力が本物か偽物かをランダムに推測します。ただし、GANの収束理論はまだ活発に研究されており、実際にはモデルが常にこの時点までトレーニングできるとは限りません。

What is a DCGAN?

DCGAN(Deep Convolutional GAN) は、上記のGANを拡張したものですが、識別器と生成器でそれぞれ畳み込み層と畳み込み転置層を使用する点が異なります。
これは、Radford氏らによって、論文「Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks」で最初に説明されました。
識別器は、strided convolution 層(畳み込み層の一種)、Batch Normalization 層、および LeakyReLU activation(活性化関数) で構成されています。入力は 3x64x64 の画像で、出力は入力データ(画像)が実在の写真である確率(スカラー値)です。
生成器は、convolutional-transpose 層、Batch Normalization 層、および ReLU activation(活性化関数) で構成されます。入力は、標準正規分布から抽出された潜在ベクトル $z$で、出力は 3x64x64 のRGB画像です。
論文では、オプティマイザーのセットアップ方法、損失関数の計算方法、モデルの重みの初期化方法に関するヒントもいくつか示しています。これらはすべて、次のセクションで説明します。

from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# 再現性のためにランダムシードを設定する
manualSeed = 999
#manualSeed = random.randint(1、10000)#新しい結果が必要な場合に使用
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

Inputs

入力を定義します:

  • dataroot - データセットフォルダのルートへのパス。次のセクションでデータセットについて詳しく説明します。
  • workers - DataLoaderでデータをロードするためのワーカースレッドの数。
  • batch_size - トレーニングで使用されるバッチサイズ。 DCGANの論文では128のバッチサイズを使用しています。
  • image_size - トレーニングに使用される画像の空間サイズ。この実装のデフォルトは64x64です。別のサイズが必要な場合は、DとGの構造を変更する必要があります。詳細はこちらをご覧ください。
  • nc - 入力画像のカラーチャンネルの数。カラー画像の場合は 3 です。
  • nz - 潜在ベクトルの長さ。
  • ngf - 生成器を介して伝播される特徴マップの深さに関連します。
  • ndf - 識別器を介して伝播される特徴マップの深さを設定します。
  • num_epochs - 実行するトレーニングエポックの数。エポック数を大きくすると、より良い結果につながりますが、時間がかかります。
  • lr - トレーニングの学習率。 DCGANの論文で説明されているように、この数値は0.0002である必要があります。
  • beta1 - Adamオプティマイザーのbeta1ハイパーパラメーター。論文で説明されているように、この数値は0.5である必要があります。
  • ngpu - 使用可能なGPUの数。0の場合、CPUモードで実行されます。この数が0より大きい場合、その数のGPUで実行されます。
# データセットのルートディレクトリ
dataroot = "data/celeba"

# データローダーのワーカー数
workers = 2

# トレーニングのバッチサイズ
batch_size = 128

# トレーニング画像の空間サイズ。
# すべての画像はトランスフォーマーを使用してこのサイズに変更されます。
image_size = 64

# トレーニング画像のチャネル数。カラー画像の場合は「3」 
nc = 3

# 潜在ベクトル z のサイズ(つまり、ジェネレータ入力のサイズ)
nz = 100

# 生成器の feature map のサイズ
ngf = 64

# 識別器の feature map のサイズ
ndf = 64

# エポック数
num_epochs = 5

# 学習率
lr = 0.0002

# Adam オプティマイザのBeta1ハイパーパラメータ
beta1 = 0.5

# 使用可能なGPUの数。0の場合、CPUモードで実行されます
ngpu = 1

Data

このチュートリアルでは、リンク先のサイト またはGoogleドライブからダウンロードできる Celeb-A Faces dataset を使用します。データセットは、img_align_celeba.zipという名前のファイルとしてダウンロードできます。ダウンロードしたら、celebaという名前のディレクトリを作成し、zipファイルをそのディレクトリに解凍します。次に、このノートブックの dataroot にcelebaディレクトリを設定します。
ディレクトリ構造は次のようになります。

   /path/to/celeba
       -> img_align_celeba  
           -> 188242.jpg
           -> 173822.jpg
           -> 284702.jpg
           -> 537394.jpg
              ...

ここでは、以下の処理を行います。

  1. データセットを作成する
  2. データローダーを作成する
  3. 実行するデバイスを設定する
  4. 最後にトレーニングデータの一部を視覚化する。

データセットのルートフォルダにはサブディレクトリが必要です。

# 画像フォルダデータセットは、以下で設定した方法で使用できます。

# データセットを作成する
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# データローダーを作成する
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# 実行するデバイスを決定する
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# トレーニング画像をプロットする
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

ダウンロード1.png

Implementation(実装)

データセットを準備したら、実装を行います。
重みの初期化から始めて、生成器、識別器、損失関数、トレーニングについて詳しく説明します。

Weight Initialization(重みの初期化)

DCGANの論文では、モデルの重みを平均 0、標準偏差 0.02 の正規分布からランダムに初期化することが明示されています。
weights_init関数は、初期化されたモデルを入力として受け取り、すべての畳み込み層、および Batch Normalization 層を再初期化します。
この関数は、初期化直後のモデルに適用されます。

# G(生成器)とD(識別器)の重みの初期化
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

Generator(生成器)

ジェネレータ $G$は、潜在空間ベクトル($z$)をデータ空間にマッピングするように設計されています。データは画像のため、$z$ をデータ空間に変換すると、最終的にトレーニング画像と同じサイズ(3x64x64)のRGB画像が作成されます。
実際には、BatchNorm2d レイヤーおよび ReLU アクティベーションとペアになっている、2次元転置畳み込み層(ConvTranspose2d)によって実現されます。
ジェネレータの出力は、tanh関数で供給され、$[-1,1]$の入力データ範囲に戻ります。
DCGANの論文のジェネレーター画像を以下に示します。

https://pytorch.org/tutorials/_images/dcgan_generator.png

入力セクションで設定した入力(nzngf、およびnc)がコードのジェネレーターアーキテクチャにどのように影響するかに注意してください。 nzは$z$入力ベクトルの長さ、ngfはジェネレータを介して伝播される特徴マップのサイズに関連し、ncは出力画像のチャネル数です(RGB画像の場合は3に設定されます)。以下はジェネレーターのコードです。

# Generator (生成器)

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # 入力は Z で、畳み込み層に渡されます
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # サイズ (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # サイズ  (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # サイズ (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # サイズ (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # サイズ (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)

ジェネレータ(生成器)をインスタンス化し、weights_init関数を適用します。
出力されたモデルをチェックし、ジェネレータがどのように構成されているか確認してください。

# ジェネレーターを作成します
netG = Generator(ngpu).to(device)

# 必要に応じてGPUを使用します
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# weights_init関数を適用して、すべての重みを平均「0」、標準偏差「0.02」でランダムに初期化します。
netG.apply(weights_init)

# モデルを出力します
print(netG)
out
Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

Discriminator(識別器)

前述のように、識別器 $D$は、画像を入力として受け取り、入力画像が(偽物ではなく)本物であるという確率を出力する二値分類ネットワークです。
$D$ は3x64x64の入力画像を取得し、それを一連のConv2d、BatchNorm2d、およびLeakyReLU 層で処理し、Sigmoid関数を介して確率を出力します。 このアーキテクチャは、必要に応じて、より多くのレイヤーに拡張できますが、Strided Convolution、BatchNorm、および LeakyReLU を使用することが重要です。
DCGANの論文では、ダウンサンプリングにプーリングではなく、Strided Convolution を使用することを勧めています。また、batch norm と leaky relu も、$G$ と $D$ の両方の学習プロセスと勾配の計算に重要です。

以下は識別器のコードです。

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # 入力は (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # サイズ (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # サイズ (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # サイズ (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # サイズ (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

これで、生成器と同様に、識別器を作成し、weights_init関数を適用して、モデルの構造を出力できます。

# Create the Discriminator
# 識別器を作成します
netD = Discriminator(ngpu).to(device)

# Handle multi-gpu if desired
# 必要に応じてGPUを使用します
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))
    
# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.2.
# weights_init関数を適用して、すべての重みを平均「0」、標準偏差「0.02」でランダムに初期化します。
netD.apply(weights_init)

# Print the model
# モデルを出力します
print(netD)
out
Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)

Loss Functions and Optimizers

損失関数とオプティマイザで、生成器($G$)、識別器($D$)の学習方法を指定できます。
損失関数には PyTorch で次のように定義されているバイナリクロスエントロピー損失関数BCELossを使用します。

$\qquad \ell(x, y) = L = {l_1,\dots,l_N}^\top, \quad l_n = - \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right]$

この損失関数の対数成分 $log(D(x))$ と $log(1-D(G(z)))$ がどのような計算になるか注意してください。損失関数に渡す $y$ の値は呼び出し元で指定できます。$y$ (正解ラベル)に何を指定するかで計算するコンポーネントを選択できることが重要です。
($y_n$ に正解ラベル 1 を指定すると誤差関数の第1項目 $\log x_n$ が残り、$y_n$ に不正解ラベル 0 を指定すると第2項目 $\log (1 - x_n)$ が残ります。)
次に、実際のラベルを1、偽のラベルを0と定義します。これは、$D$ と $G$ の損失を計算するときに使用されます。
最後に、$D$ と$G$ の2つのオプティマイザを設定します。 DCGANの論文で指定されているように、どちらも学習率 0.0002 および Beta1 = 0.5 の Adam オプティマイザです。
ジェネレータの学習の進行状況を追跡するために、標準正規分布からランダムに抽出された潜在ベクトルの固定値 fixed_noise を生成します。トレーニングループでは、このfixed_noiseを定期的に $G$ に入力し、ループ中に固定値 fixed_noise から画像が形成されるのを確認します。

# BCELoss関数を初期化します
criterion = nn.BCELoss()

# ジェネレータの進行を視覚化するために使用する潜在ベクトルを作成します
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# トレーニング中に本物のラベルと偽のラベルのルールを設定します
real_label = 1.
fake_label = 0.

# G と D に Adam オプティマイザを設定する
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

Training

GANフレームワークのすべてを定義したので、トレーニングを実行できます。
ただし、パラメータの設定が正しくないとうまく学習できず、何が悪かったのかもほとんど分かりません。
ここでは、ganhacks に示されているいくつかのベストプラクティスを順守しながら、Goodfellowの論文のアルゴリズム1に従い、トレーニングします。
つまり、「実際の画像と偽の画像に異なるミニバッチを作成」し、「$logD(G(z))$ を最大化する」ように $G$ の目的関数を調整します。
トレーニングは2つの主要な部分に分かれています。パート1は識別器のパラメータを更新し、パート2は生成器を更新します。

Part 1 - Train the Discriminator

識別器をトレーニングする目的は、入力画像を本物か偽物を正しく分類する確率を最大化することです。
グッドフェロー氏の論文では、「確率的勾配を上昇させることによって識別器を更新する」ことが記述されています。実際には、$log(D(x)) + log(1-D(G(z)))$ を最大化する必要があります。
ganhacksとは異なるやり方ですが、以下の2つのステップで実行します。
1つ目のステップは、トレーニングデータセットから実在の画像のバッチを作成し、
$D$ を順伝播し、損失($log(D(x))$)を計算してから逆伝播の勾配を計算します。
2つ目のステップは、生成器を使用して偽の画像のバッチを作成し、$D$ を順伝播、損失($log(1-D(G(z)))$)を計算してから逆伝播の勾配を計算します。
2つのステップの両方で蓄積された勾配を使用して、識別器のオプティマイザの step を実行します。

Part 2 - Train the Generator

先の論文で述べているように、より良い偽物を生成するために、$log(1-D(G(z)))$ を最小化することによって生成器をトレーニングしたいと思います。
しかしながら、トレーニングの初期段階でうまく学習されないことがグッドフェロー氏によって示されています。
代わりに $log(D(G(z)))$ を最大化するように修正します。
以下のコードでは、パート1の生成器の出力を識別器で分類し、実際のラベルを正解ラベルとして $G$ の損失を計算し、逆伝播で $G$ の勾配を計算、最後にオプティマイザで$G$のパラメーターを更新することでこれを実現します。
損失関数の正解ラベルとして実際のラベルを使用することは直感に反するように思われるかもしれませんが、これにより、BCELossの $log(x)$ 部分($log(1-x)$部分ではなく)を使用できるようになります。

最後に、いくつかの統計レポートを作成しつつ、各エポックの終わりに fixed_noise をもとに生成器で画像を生成し、Gのトレーニングの進行状況を確認します。
報告されるトレーニング統計は次のとおりです。

  • Loss_D - 識別器の損失。実在の画像を識別した時の損失 $log(D(x))$ と偽の画像を識別した時の損失 $log(1-D(G(z)))$ の合計値($log(D(x)) + log(1-D(G(z)))$)。(※ チュートリアル原文の式が誤っている気がしますがどうでしょうか。)
  • Loss_G - $log(D(G(z)))$ として計算される生成器の損失
  • D(x) - 実際の画像の識別器の(バッチ全体の)平均。これは1近くから始まり、Gが良くなると理論的には0.5に収束するはずです。
  • D(G(z)) - 偽の画像の識別器の平均。最初の番号は$D$が更新される前であり、2番目の番号は$D$が更新された後です。これらの数値は0付近から始まり、$G$が良くなるにつれて0.5に収束するはずです。

注意 実行するエポックの数やデータセットからデータを削除したかどうかによっては、この手順に時間がかかる場合があります。

※Colaboratory の環境では入力データ(画像)をアップロードしきれず、実行できませんでした。自分のPC環境(古いPCでGPUはGTX 1050Ti です)だと、画像約20万枚、エポック数 5 で1時間ちょっとかかりました。

# トレーニングループ

# 進捗状況を追跡するためのリスト
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# エポックごとのループ
for epoch in range(num_epochs):
    # データローダーのバッチごとのループ
    for i, data in enumerate(dataloader, 0):
        
        ############################
        # (1) Dネットワークの更新:log(D(x)) + log(1 - D(G(z))) を最大化します
        ###########################
        ## 実在の画像でトレーニングします
        netD.zero_grad()
        # バッチのフォーマット
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # 実在の写真で D の順伝播させます
        output = netD(real_cpu).view(-1)
        # 損失を計算します
        errD_real = criterion(output, label)
        # 逆伝播でDの勾配を計算します
        errD_real.backward()
        D_x = output.mean().item()

        ## 偽の画像でトレーニングします
        # 潜在ベクトルのバッチを生成します
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Gで偽の画像を生成します
        fake = netG(noise)
        label.fill_(fake_label)
        # 生成した偽画像をDで分類します
        output = netD(fake.detach()).view(-1)
        # Dの損失を計算します
        errD_fake = criterion(output, label)
        # 勾配を計算します
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # 実在の画像の勾配と偽画像の勾配を加算します
        errD = errD_real + errD_fake
        # Dを更新します
        optimizerD.step()

        ############################
        # (2) Gネットワ​​ークの更新:log(D(G(z))) を最大化します
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # 偽のラベルは生成器の損失にとって本物です
        # パラメータ更新後のDを利用して、偽画像を順伝播させます
        output = netD(fake).view(-1)
        # この出力に基づいてGの損失を計算します
        errG = criterion(output, label)
        # Gの勾配を計算します
        errG.backward()
        D_G_z2 = output.mean().item()
        # Gを更新します
        optimizerG.step()
        
        # トレーニング統計を出力します
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        
        # 後でプロットするために損失を保存します
        G_losses.append(errG.item())
        D_losses.append(errD.item())
        
        # fixed_noiseによる G の出力を保存し、生成器の精度を確認します
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
            
        iters += 1
out
Starting Training Loop...
[0/5][0/1583]    Loss_D: 1.7825    Loss_G: 5.0967    D(x): 0.5566    D(G(z)): 0.5963 / 0.0094
[0/5][50/1583]    Loss_D: 1.2308    Loss_G: 21.3185    D(x): 0.9840    D(G(z)): 0.5805 / 0.0000
[0/5][100/1583]    Loss_D: 0.2915    Loss_G: 6.6904    D(x): 0.9769    D(G(z)): 0.1790 / 0.0028
[0/5][150/1583]    Loss_D: 0.3132    Loss_G: 5.4999    D(x): 0.8093    D(G(z)): 0.0210 / 0.0094
[0/5][200/1583]    Loss_D: 1.8818    Loss_G: 10.8477    D(x): 0.4127    D(G(z)): 0.0022 / 0.0021
[0/5][250/1583]    Loss_D: 0.3169    Loss_G: 3.5397    D(x): 0.8246    D(G(z)): 0.0601 / 0.0547
[0/5][300/1583]    Loss_D: 0.3551    Loss_G: 3.1239    D(x): 0.9287    D(G(z)): 0.2033 / 0.0873
[0/5][350/1583]    Loss_D: 0.3831    Loss_G: 3.4443    D(x): 0.8077    D(G(z)): 0.1026 / 0.0508
[0/5][400/1583]    Loss_D: 0.4064    Loss_G: 4.7648    D(x): 0.8264    D(G(z)): 0.1073 / 0.0158
[0/5][450/1583]    Loss_D: 0.5639    Loss_G: 6.6930    D(x): 0.9511    D(G(z)): 0.3329 / 0.0031
[0/5][500/1583]    Loss_D: 1.3354    Loss_G: 4.8289    D(x): 0.9831    D(G(z)): 0.6628 / 0.0210
[0/5][550/1583]    Loss_D: 0.4430    Loss_G: 5.3935    D(x): 0.8112    D(G(z)): 0.1430 / 0.0082
[0/5][600/1583]    Loss_D: 0.4703    Loss_G: 3.0324    D(x): 0.7795    D(G(z)): 0.1257 / 0.0754
[0/5][650/1583]    Loss_D: 1.7328    Loss_G: 3.3086    D(x): 0.3306    D(G(z)): 0.0012 / 0.1422
[0/5][700/1583]    Loss_D: 0.3655    Loss_G: 5.8728    D(x): 0.9177    D(G(z)): 0.2211 / 0.0053
[0/5][750/1583]    Loss_D: 0.2753    Loss_G: 5.8153    D(x): 0.9107    D(G(z)): 0.1408 / 0.0061
[0/5][800/1583]    Loss_D: 0.6434    Loss_G: 3.9464    D(x): 0.7840    D(G(z)): 0.2572 / 0.0382
[0/5][850/1583]    Loss_D: 0.4208    Loss_G: 3.9550    D(x): 0.8904    D(G(z)): 0.1904 / 0.0348
[0/5][900/1583]    Loss_D: 0.1478    Loss_G: 4.9144    D(x): 0.9053    D(G(z)): 0.0304 / 0.0184
[0/5][950/1583]    Loss_D: 0.4626    Loss_G: 3.9822    D(x): 0.7520    D(G(z)): 0.0995 / 0.0307
[0/5][1000/1583]    Loss_D: 1.3251    Loss_G: 6.6428    D(x): 0.8265    D(G(z)): 0.5140 / 0.0039
[0/5][1050/1583]    Loss_D: 0.2756    Loss_G: 4.9520    D(x): 0.9536    D(G(z)): 0.1760 / 0.0161
[0/5][1100/1583]    Loss_D: 0.4271    Loss_G: 4.3152    D(x): 0.8585    D(G(z)): 0.1801 / 0.0276
[0/5][1150/1583]    Loss_D: 0.4866    Loss_G: 5.8982    D(x): 0.9237    D(G(z)): 0.2801 / 0.0061
[0/5][1200/1583]    Loss_D: 0.5454    Loss_G: 3.4798    D(x): 0.7154    D(G(z)): 0.0932 / 0.0484
[0/5][1250/1583]    Loss_D: 0.7009    Loss_G: 3.6421    D(x): 0.7268    D(G(z)): 0.2214 / 0.0448
[0/5][1300/1583]    Loss_D: 0.9729    Loss_G: 6.5268    D(x): 0.9301    D(G(z)): 0.4758 / 0.0049
[0/5][1350/1583]    Loss_D: 0.2977    Loss_G: 4.4779    D(x): 0.8450    D(G(z)): 0.0783 / 0.0217
[0/5][1400/1583]    Loss_D: 0.4291    Loss_G: 4.7027    D(x): 0.7973    D(G(z)): 0.1081 / 0.0232
[0/5][1450/1583]    Loss_D: 0.5062    Loss_G: 1.7627    D(x): 0.7334    D(G(z)): 0.0914 / 0.2356
[0/5][1500/1583]    Loss_D: 0.3683    Loss_G: 2.8416    D(x): 0.7972    D(G(z)): 0.0877 / 0.0884
[0/5][1550/1583]    Loss_D: 1.2416    Loss_G: 6.1441    D(x): 0.9658    D(G(z)): 0.6089 / 0.0039
[1/5][0/1583]    Loss_D: 0.4863    Loss_G: 4.1000    D(x): 0.8147    D(G(z)): 0.1864 / 0.0311
[1/5][50/1583]    Loss_D: 0.5487    Loss_G: 4.3008    D(x): 0.8068    D(G(z)): 0.2321 / 0.0231
[1/5][100/1583]    Loss_D: 0.3973    Loss_G: 3.1769    D(x): 0.8267    D(G(z)): 0.1488 / 0.0607
[1/5][150/1583]    Loss_D: 0.4199    Loss_G: 3.7225    D(x): 0.8362    D(G(z)): 0.1711 / 0.0388
[1/5][200/1583]    Loss_D: 0.3704    Loss_G: 3.1646    D(x): 0.7950    D(G(z)): 0.0742 / 0.0690
[1/5][250/1583]    Loss_D: 0.4835    Loss_G: 2.6290    D(x): 0.7500    D(G(z)): 0.0945 / 0.1090
[1/5][300/1583]    Loss_D: 1.7134    Loss_G: 0.8583    D(x): 0.3044    D(G(z)): 0.0100 / 0.5134
[1/5][350/1583]    Loss_D: 1.2287    Loss_G: 1.1995    D(x): 0.4230    D(G(z)): 0.0291 / 0.4012
[1/5][400/1583]    Loss_D: 0.3184    Loss_G: 3.1359    D(x): 0.8466    D(G(z)): 0.1128 / 0.0672
[1/5][450/1583]    Loss_D: 0.4920    Loss_G: 2.9625    D(x): 0.7946    D(G(z)): 0.1426 / 0.0824
[1/5][500/1583]    Loss_D: 0.4525    Loss_G: 3.3157    D(x): 0.8559    D(G(z)): 0.2031 / 0.0599
[1/5][550/1583]    Loss_D: 0.3221    Loss_G: 3.4189    D(x): 0.8521    D(G(z)): 0.1203 / 0.0591
[1/5][600/1583]    Loss_D: 0.5004    Loss_G: 3.6299    D(x): 0.8184    D(G(z)): 0.2133 / 0.0486
[1/5][650/1583]    Loss_D: 0.4997    Loss_G: 4.4381    D(x): 0.8822    D(G(z)): 0.2777 / 0.0170
[1/5][700/1583]    Loss_D: 0.3672    Loss_G: 2.6801    D(x): 0.8161    D(G(z)): 0.1072 / 0.1065
[1/5][750/1583]    Loss_D: 0.4129    Loss_G: 2.7585    D(x): 0.8070    D(G(z)): 0.1383 / 0.0915
[1/5][800/1583]    Loss_D: 1.3243    Loss_G: 6.4724    D(x): 0.9383    D(G(z)): 0.6237 / 0.0046
[1/5][850/1583]    Loss_D: 0.4103    Loss_G: 3.9704    D(x): 0.8653    D(G(z)): 0.2067 / 0.0287
[1/5][900/1583]    Loss_D: 0.5130    Loss_G: 2.8125    D(x): 0.7487    D(G(z)): 0.1296 / 0.0867
[1/5][950/1583]    Loss_D: 0.3896    Loss_G: 3.7857    D(x): 0.8697    D(G(z)): 0.1867 / 0.0340
[1/5][1000/1583]    Loss_D: 0.3782    Loss_G: 4.0259    D(x): 0.8248    D(G(z)): 0.1319 / 0.0324
[1/5][1050/1583]    Loss_D: 0.3561    Loss_G: 3.6737    D(x): 0.8014    D(G(z)): 0.0939 / 0.0508
[1/5][1100/1583]    Loss_D: 1.1566    Loss_G: 6.8887    D(x): 0.9395    D(G(z)): 0.5788 / 0.0030
[1/5][1150/1583]    Loss_D: 0.4082    Loss_G: 2.7026    D(x): 0.8282    D(G(z)): 0.1693 / 0.0954
[1/5][1200/1583]    Loss_D: 0.4618    Loss_G: 2.0251    D(x): 0.7042    D(G(z)): 0.0517 / 0.1746
[1/5][1250/1583]    Loss_D: 0.4982    Loss_G: 3.7400    D(x): 0.8717    D(G(z)): 0.2661 / 0.0350
[1/5][1300/1583]    Loss_D: 0.8765    Loss_G: 2.0860    D(x): 0.5928    D(G(z)): 0.1493 / 0.1935
[1/5][1350/1583]    Loss_D: 0.4507    Loss_G: 2.7327    D(x): 0.8017    D(G(z)): 0.1732 / 0.0891
[1/5][1400/1583]    Loss_D: 0.6554    Loss_G: 4.4541    D(x): 0.9331    D(G(z)): 0.3904 / 0.0190
[1/5][1450/1583]    Loss_D: 0.3707    Loss_G: 3.0559    D(x): 0.8092    D(G(z)): 0.1124 / 0.0752
[1/5][1500/1583]    Loss_D: 0.3978    Loss_G: 2.9060    D(x): 0.8130    D(G(z)): 0.1448 / 0.0722
[1/5][1550/1583]    Loss_D: 0.6422    Loss_G: 3.3732    D(x): 0.7391    D(G(z)): 0.2337 / 0.0562
[2/5][0/1583]    Loss_D: 0.5571    Loss_G: 4.3724    D(x): 0.8466    D(G(z)): 0.2813 / 0.0222
[2/5][50/1583]    Loss_D: 0.9171    Loss_G: 3.3607    D(x): 0.5004    D(G(z)): 0.0169 / 0.0584
[2/5][100/1583]    Loss_D: 0.5236    Loss_G: 2.8583    D(x): 0.8345    D(G(z)): 0.2470 / 0.0811
[2/5][150/1583]    Loss_D: 0.6476    Loss_G: 2.4148    D(x): 0.6520    D(G(z)): 0.0886 / 0.1282
[2/5][200/1583]    Loss_D: 0.4759    Loss_G: 2.1868    D(x): 0.7868    D(G(z)): 0.1683 / 0.1432
[2/5][250/1583]    Loss_D: 0.4453    Loss_G: 3.5857    D(x): 0.8142    D(G(z)): 0.1816 / 0.0417
[2/5][300/1583]    Loss_D: 0.6185    Loss_G: 1.9197    D(x): 0.6361    D(G(z)): 0.0578 / 0.1909
[2/5][350/1583]    Loss_D: 0.5656    Loss_G: 1.8746    D(x): 0.6823    D(G(z)): 0.0907 / 0.2027
[2/5][400/1583]    Loss_D: 0.5084    Loss_G: 3.8005    D(x): 0.8927    D(G(z)): 0.2975 / 0.0310
[2/5][450/1583]    Loss_D: 0.6327    Loss_G: 3.6768    D(x): 0.8792    D(G(z)): 0.3508 / 0.0370
[2/5][500/1583]    Loss_D: 1.0236    Loss_G: 0.9627    D(x): 0.5368    D(G(z)): 0.2148 / 0.4402
[2/5][550/1583]    Loss_D: 0.5240    Loss_G: 1.8640    D(x): 0.6880    D(G(z)): 0.0900 / 0.1949
[2/5][600/1583]    Loss_D: 1.4416    Loss_G: 0.7844    D(x): 0.3551    D(G(z)): 0.0359 / 0.5423
[2/5][650/1583]    Loss_D: 0.7049    Loss_G: 2.6899    D(x): 0.7614    D(G(z)): 0.3058 / 0.0898
[2/5][700/1583]    Loss_D: 0.6580    Loss_G: 2.1950    D(x): 0.6655    D(G(z)): 0.1447 / 0.1515
[2/5][750/1583]    Loss_D: 0.4866    Loss_G: 2.9747    D(x): 0.8271    D(G(z)): 0.2238 / 0.0660
[2/5][800/1583]    Loss_D: 0.8614    Loss_G: 3.2034    D(x): 0.8388    D(G(z)): 0.4398 / 0.0602
[2/5][850/1583]    Loss_D: 0.5973    Loss_G: 1.9877    D(x): 0.7525    D(G(z)): 0.2234 / 0.1688
[2/5][900/1583]    Loss_D: 0.6860    Loss_G: 1.3544    D(x): 0.6354    D(G(z)): 0.1454 / 0.3062
[2/5][950/1583]    Loss_D: 0.5878    Loss_G: 2.2449    D(x): 0.7269    D(G(z)): 0.1780 / 0.1410
[2/5][1000/1583]    Loss_D: 1.0009    Loss_G: 1.7307    D(x): 0.4573    D(G(z)): 0.0505 / 0.2349
[2/5][1050/1583]    Loss_D: 0.7103    Loss_G: 1.6821    D(x): 0.6081    D(G(z)): 0.1171 / 0.2361
[2/5][1100/1583]    Loss_D: 1.2062    Loss_G: 1.9829    D(x): 0.5149    D(G(z)): 0.2336 / 0.1943
[2/5][1150/1583]    Loss_D: 0.5601    Loss_G: 3.1371    D(x): 0.8733    D(G(z)): 0.3131 / 0.0597
[2/5][1200/1583]    Loss_D: 0.4691    Loss_G: 2.6874    D(x): 0.8428    D(G(z)): 0.2337 / 0.0879
[2/5][1250/1583]    Loss_D: 0.8933    Loss_G: 3.6654    D(x): 0.8084    D(G(z)): 0.4449 / 0.0375
[2/5][1300/1583]    Loss_D: 1.1260    Loss_G: 1.2977    D(x): 0.4179    D(G(z)): 0.0427 / 0.3397
[2/5][1350/1583]    Loss_D: 0.5914    Loss_G: 2.9971    D(x): 0.8794    D(G(z)): 0.3353 / 0.0661
[2/5][1400/1583]    Loss_D: 0.5957    Loss_G: 1.9661    D(x): 0.6747    D(G(z)): 0.1309 / 0.1766
[2/5][1450/1583]    Loss_D: 0.6327    Loss_G: 3.0287    D(x): 0.8845    D(G(z)): 0.3597 / 0.0633
[2/5][1500/1583]    Loss_D: 0.6132    Loss_G: 2.0735    D(x): 0.6484    D(G(z)): 0.1048 / 0.1586
[2/5][1550/1583]    Loss_D: 0.6960    Loss_G: 3.9634    D(x): 0.8539    D(G(z)): 0.3680 / 0.0252
[3/5][0/1583]    Loss_D: 1.1915    Loss_G: 2.0190    D(x): 0.4046    D(G(z)): 0.0497 / 0.1893
[3/5][50/1583]    Loss_D: 0.5441    Loss_G: 2.6981    D(x): 0.8116    D(G(z)): 0.2500 / 0.0867
[3/5][100/1583]    Loss_D: 0.8563    Loss_G: 3.1579    D(x): 0.7234    D(G(z)): 0.3582 / 0.0524
[3/5][150/1583]    Loss_D: 0.5450    Loss_G: 2.2794    D(x): 0.7424    D(G(z)): 0.1788 / 0.1285
[3/5][200/1583]    Loss_D: 1.5069    Loss_G: 5.1980    D(x): 0.9165    D(G(z)): 0.6954 / 0.0112
[3/5][250/1583]    Loss_D: 0.9336    Loss_G: 4.0916    D(x): 0.9397    D(G(z)): 0.5212 / 0.0248
[3/5][300/1583]    Loss_D: 0.4925    Loss_G: 1.6877    D(x): 0.7714    D(G(z)): 0.1743 / 0.2235
[3/5][350/1583]    Loss_D: 0.5780    Loss_G: 2.4934    D(x): 0.8048    D(G(z)): 0.2654 / 0.1000
[3/5][400/1583]    Loss_D: 0.6465    Loss_G: 2.2793    D(x): 0.7224    D(G(z)): 0.2318 / 0.1355
[3/5][450/1583]    Loss_D: 0.7550    Loss_G: 1.4872    D(x): 0.5877    D(G(z)): 0.1296 / 0.2714
[3/5][500/1583]    Loss_D: 0.5376    Loss_G: 3.1298    D(x): 0.8589    D(G(z)): 0.2911 / 0.0580
[3/5][550/1583]    Loss_D: 0.7182    Loss_G: 1.3164    D(x): 0.6177    D(G(z)): 0.1423 / 0.3221
[3/5][600/1583]    Loss_D: 0.8334    Loss_G: 3.9036    D(x): 0.8838    D(G(z)): 0.4598 / 0.0299
[3/5][650/1583]    Loss_D: 0.5437    Loss_G: 2.0172    D(x): 0.8044    D(G(z)): 0.2493 / 0.1646
[3/5][700/1583]    Loss_D: 0.5019    Loss_G: 2.5320    D(x): 0.8308    D(G(z)): 0.2389 / 0.1014
[3/5][750/1583]    Loss_D: 0.6720    Loss_G: 1.9764    D(x): 0.7117    D(G(z)): 0.2307 / 0.1763
[3/5][800/1583]    Loss_D: 0.5250    Loss_G: 2.3763    D(x): 0.7334    D(G(z)): 0.1590 / 0.1205
[3/5][850/1583]    Loss_D: 0.9347    Loss_G: 3.8405    D(x): 0.8708    D(G(z)): 0.5005 / 0.0300
[3/5][900/1583]    Loss_D: 0.7929    Loss_G: 3.0115    D(x): 0.7821    D(G(z)): 0.3699 / 0.0696
[3/5][950/1583]    Loss_D: 0.5819    Loss_G: 1.6618    D(x): 0.6905    D(G(z)): 0.1401 / 0.2291
[3/5][1000/1583]    Loss_D: 0.6286    Loss_G: 2.7553    D(x): 0.8014    D(G(z)): 0.3004 / 0.0802
[3/5][1050/1583]    Loss_D: 0.7823    Loss_G: 3.1819    D(x): 0.7950    D(G(z)): 0.3752 / 0.0611
[3/5][1100/1583]    Loss_D: 0.8769    Loss_G: 3.5589    D(x): 0.9221    D(G(z)): 0.4930 / 0.0420
[3/5][1150/1583]    Loss_D: 0.8988    Loss_G: 5.1972    D(x): 0.9245    D(G(z)): 0.5103 / 0.0088
[3/5][1200/1583]    Loss_D: 0.9517    Loss_G: 4.5984    D(x): 0.8263    D(G(z)): 0.4784 / 0.0145
[3/5][1250/1583]    Loss_D: 0.8584    Loss_G: 3.2372    D(x): 0.8411    D(G(z)): 0.4431 / 0.0538
[3/5][1300/1583]    Loss_D: 0.9749    Loss_G: 4.0050    D(x): 0.9209    D(G(z)): 0.5348 / 0.0284
[3/5][1350/1583]    Loss_D: 1.0938    Loss_G: 0.3283    D(x): 0.4292    D(G(z)): 0.0768 / 0.7444
[3/5][1400/1583]    Loss_D: 0.6101    Loss_G: 2.1225    D(x): 0.7091    D(G(z)): 0.1894 / 0.1491
[3/5][1450/1583]    Loss_D: 0.5325    Loss_G: 2.2364    D(x): 0.6970    D(G(z)): 0.1181 / 0.1433
[3/5][1500/1583]    Loss_D: 0.8403    Loss_G: 3.3129    D(x): 0.8605    D(G(z)): 0.4357 / 0.0549
[3/5][1550/1583]    Loss_D: 0.5560    Loss_G: 2.1209    D(x): 0.7575    D(G(z)): 0.2112 / 0.1460
[4/5][0/1583]    Loss_D: 1.2497    Loss_G: 0.6124    D(x): 0.3445    D(G(z)): 0.0367 / 0.5943
[4/5][50/1583]    Loss_D: 1.0927    Loss_G: 0.9180    D(x): 0.4104    D(G(z)): 0.0462 / 0.4446
[4/5][100/1583]    Loss_D: 0.9225    Loss_G: 0.4822    D(x): 0.4784    D(G(z)): 0.0647 / 0.6530
[4/5][150/1583]    Loss_D: 0.5889    Loss_G: 2.3969    D(x): 0.7178    D(G(z)): 0.1811 / 0.1196
[4/5][200/1583]    Loss_D: 0.5160    Loss_G: 2.1426    D(x): 0.6925    D(G(z)): 0.1009 / 0.1535
[4/5][250/1583]    Loss_D: 0.6147    Loss_G: 1.8767    D(x): 0.6149    D(G(z)): 0.0715 / 0.1976
[4/5][300/1583]    Loss_D: 0.7196    Loss_G: 1.8276    D(x): 0.6679    D(G(z)): 0.2232 / 0.1950
[4/5][350/1583]    Loss_D: 0.6552    Loss_G: 3.0728    D(x): 0.8624    D(G(z)): 0.3590 / 0.0594
[4/5][400/1583]    Loss_D: 0.4485    Loss_G: 2.0510    D(x): 0.7991    D(G(z)): 0.1797 / 0.1553
[4/5][450/1583]    Loss_D: 0.5292    Loss_G: 2.2283    D(x): 0.7730    D(G(z)): 0.2061 / 0.1299
[4/5][500/1583]    Loss_D: 0.9299    Loss_G: 2.7818    D(x): 0.7656    D(G(z)): 0.4220 / 0.0857
[4/5][550/1583]    Loss_D: 0.4178    Loss_G: 3.2685    D(x): 0.8952    D(G(z)): 0.2436 / 0.0496
[4/5][600/1583]    Loss_D: 0.8358    Loss_G: 1.6605    D(x): 0.5853    D(G(z)): 0.1576 / 0.2465
[4/5][650/1583]    Loss_D: 2.2432    Loss_G: 2.5945    D(x): 0.9064    D(G(z)): 0.8167 / 0.1101
[4/5][700/1583]    Loss_D: 0.5702    Loss_G: 2.5894    D(x): 0.7715    D(G(z)): 0.2206 / 0.0984
[4/5][750/1583]    Loss_D: 1.3677    Loss_G: 1.1562    D(x): 0.3420    D(G(z)): 0.0918 / 0.3838
[4/5][800/1583]    Loss_D: 0.5114    Loss_G: 2.1584    D(x): 0.7484    D(G(z)): 0.1617 / 0.1515
[4/5][850/1583]    Loss_D: 1.0217    Loss_G: 0.9296    D(x): 0.4531    D(G(z)): 0.0982 / 0.4508
[4/5][900/1583]    Loss_D: 0.9362    Loss_G: 2.2130    D(x): 0.7036    D(G(z)): 0.3690 / 0.1460
[4/5][950/1583]    Loss_D: 0.5356    Loss_G: 1.9746    D(x): 0.7591    D(G(z)): 0.1908 / 0.1650
[4/5][1000/1583]    Loss_D: 0.6062    Loss_G: 3.4775    D(x): 0.8733    D(G(z)): 0.3460 / 0.0390
[4/5][1050/1583]    Loss_D: 0.7407    Loss_G: 1.4774    D(x): 0.5800    D(G(z)): 0.1078 / 0.2783
[4/5][1100/1583]    Loss_D: 0.6590    Loss_G: 1.7840    D(x): 0.6090    D(G(z)): 0.0856 / 0.2158
[4/5][1150/1583]    Loss_D: 1.1789    Loss_G: 4.3504    D(x): 0.8746    D(G(z)): 0.5924 / 0.0210
[4/5][1200/1583]    Loss_D: 0.5922    Loss_G: 2.0222    D(x): 0.7280    D(G(z)): 0.1955 / 0.1713
[4/5][1250/1583]    Loss_D: 0.6910    Loss_G: 1.3199    D(x): 0.6022    D(G(z)): 0.1046 / 0.3194
[4/5][1300/1583]    Loss_D: 0.6726    Loss_G: 2.6413    D(x): 0.8492    D(G(z)): 0.3557 / 0.0972
[4/5][1350/1583]    Loss_D: 0.7808    Loss_G: 3.7603    D(x): 0.8605    D(G(z)): 0.4220 / 0.0317
[4/5][1400/1583]    Loss_D: 0.6173    Loss_G: 3.5769    D(x): 0.9113    D(G(z)): 0.3699 / 0.0392
[4/5][1450/1583]    Loss_D: 0.4523    Loss_G: 2.5795    D(x): 0.7925    D(G(z)): 0.1731 / 0.0977
[4/5][1500/1583]    Loss_D: 0.4865    Loss_G: 2.3129    D(x): 0.8465    D(G(z)): 0.2489 / 0.1228
[4/5][1550/1583]    Loss_D: 1.0133    Loss_G: 4.2263    D(x): 0.8779    D(G(z)): 0.5241 / 0.0254

Results

最後に、結果を確認しましょう。ここでは、3つの異なる結果を見ていきます。
まず、トレーニング中に $D$ と $G$ の損失がどのように変化したかを確認します。
次に、各エポックのfixed_noiseによる $G$ の出力を視覚化します。
最後に、$G$ で生成した画像と、実際の画像を並べてみます。

損失とトレーニングの反復

以下は、D と G の損失とトレーニングの反復のプロットです。

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

ダウンロード2.png

Gの進行の視覚化

トレーニングの各エポックで、fixed_noiseで生成器の出力画像を保存しました。これで、Gのトレーニングの進行をアニメーションで視覚化できます。再生ボタンを押してアニメーションを開始します。(Qiitaではアニメーションを貼り付けられませんのでGIFでアップロードしています)

#%%capture
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())

output.gif

実画像と偽画像

最後に、いくつかの実際の画像と偽の画像を並べて見てみましょう。

# データローダから実際の画像のバッチを取得します
real_batch = next(iter(dataloader))

# 実際の画像をプロットします
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# 最後のエポックからの偽の画像をプロットします
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()

ダウンロード4.png

Where to Go Next

チュートリアルは以上ですが、次にできることがいくつかあります。

  • より長くトレーニングし、結果が良くなるか確認します。
  • このモデルを変更して別のデータセットを取得し、場合によっては画像のサイズとモデルアーキテクチャを変更します。
  • ここで他のGANプロジェクトをチェックしてください。
  • 音楽を生成するGANを作成します。

終わりに

今回は、GAN(敵対的生成ネットワーク)を学びました。
次回は「Audio I/O and Pre-Processing with torchaudio」を進めてみたいと思います。

履歴

2021/2/3 初版公開

3
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?