はじめに
前回に引き続き、PyTorch 公式チュートリアル の第11弾です。
今回は DCGAN Tutorial を進めます。
DCGAN Tutorial
Introduction
このチュートリアルでは、DCGANを紹介します。実在の有名人の画像をもとに新しい有名人の画像を生成する敵対的生成ネットワーク(GAN)をトレーニングします。
このチュートリアルのコードは、pytorch/examples にある dcgan の実装を利用しています。GANの事前知識は必要ありません。処理時間の都合上、GPUを利用したほうがよいでしょう。
それでは始めてみましょう。
Generative Adversarial Networks
What is a GAN?
GANは、ニューラルネットワークのモデルにトレーニングデータの分布を学習させるためのフレームワークで、学習した分布から新しいデータを生成できます。
GANは、2014年にイアン・グッドフェロー氏によって考え出され、論文 Generative Adversarial Nets で最初に発表されました。
GANは、**生成器(generator)と識別器(discriminator)**の2つのモデルで構成されています。
生成器は、トレーニング画像を元に「偽の」画像を生成します。
一方、識別器は、画像が実際のトレーニング画像であるか、生成器からの偽の画像であるかを判別します。
生成器は識別器が判断できないような良い偽物を生成することで学習し、識別器は実際の画像と偽の画像を正しく分類するように学習します。
この学習のゴールは、生成器がトレーニングデータのように見える完全な偽物を生成しているときであり、識別器は、生成器の出力が本物か偽物かを常に50%の信頼度で推測します。
それでは、識別器から見ていきましょう。
$x$ を画像を表すデータとします。
$D(x)$ は、$x$が本物の画像である確率(スカラー値)を出力する識別器です。ここでは、画像を扱っているため、$D(x)$ への入力はCHW(色、縦、横)で、このチュートリアルでは 3x64x64 の画像です。
$D(x)$ のスカラー値は $x$ が本物の画像の場合は大きい値になり、生成器が作製した画像の場合は小さい値になります。$D(x)$ は通常の 2クラス分類器と考えることもできます。
次に生成器を見ていきます。$z$ を標準正規分布からサンプリングされた潜在空間ベクトルとします。 $G(z)$ は、潜在ベクトル $z$ をデータ空間にマッピングするジェネレーター関数を表します。 $G$ の目標は、トレーニングデータの取得元の分布 ($p_{data}$) を推定して、推定された分布 ($p_g$) から偽のサンプルを生成できるようにすることです。
(ここでいう潜在ベクトルは、n次元の標準化されたランダム値のことです。このチュートリアルでは 100次元としています。簡単のため2次元で考えると、点(1,1)を入力すると1つの画像が作成され、点(3,2)を入力すると別の画像が作成されます。以下の図のように入力の点を徐々に変化させると、生成される画像も徐々に変化するイメージです。※実際には標準化されているため(1,1)や(3,2)の点はありません)
$D(G(z))$ は、ジェネレータ $G$ の出力が実像である確率(スカラー)です。 グッドフェロー氏の論文で説明されているように、$D$ と $G$ は、$D$ が実数と偽物を正しく分類する確率($logD(x)$)を最大化しようとし、$G$ が出力が偽物であると予測する確率( $log(1-D(G(x)))$ )を最小化しようとするミニマックスゲームをプレイします。論文から、GAN損失関数は以下で定義されます。
$\underset{G}{\text{min}} \underset{D}{\text{max}}V(D,G) = \mathbb{E}_{x\sim p_{data}(x)}\big[logD(x)\big] + \mathbb{E}_{z\sim p_{z}(z)}\big[log(1-D(G(z)))\big] $
理論的には、このミニマックスゲームの解は、$p_g = p_{data}$ であり、識別器は入力が本物か偽物かをランダムに推測します。ただし、GANの収束理論はまだ活発に研究されており、実際にはモデルが常にこの時点までトレーニングできるとは限りません。
What is a DCGAN?
DCGAN(Deep Convolutional GAN) は、上記のGANを拡張したものですが、識別器と生成器でそれぞれ畳み込み層と畳み込み転置層を使用する点が異なります。
これは、Radford氏らによって、論文「Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks」で最初に説明されました。
識別器は、strided convolution 層(畳み込み層の一種)、Batch Normalization 層、および LeakyReLU activation(活性化関数) で構成されています。入力は 3x64x64 の画像で、出力は入力データ(画像)が実在の写真である確率(スカラー値)です。
生成器は、convolutional-transpose 層、Batch Normalization 層、および ReLU activation(活性化関数) で構成されます。入力は、標準正規分布から抽出された潜在ベクトル $z$で、出力は 3x64x64 のRGB画像です。
論文では、オプティマイザーのセットアップ方法、損失関数の計算方法、モデルの重みの初期化方法に関するヒントもいくつか示しています。これらはすべて、次のセクションで説明します。
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
# 再現性のためにランダムシードを設定する
manualSeed = 999
#manualSeed = random.randint(1、10000)#新しい結果が必要な場合に使用
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
Inputs
入力を定義します:
- dataroot - データセットフォルダのルートへのパス。次のセクションでデータセットについて詳しく説明します。
- workers - DataLoaderでデータをロードするためのワーカースレッドの数。
- batch_size - トレーニングで使用されるバッチサイズ。 DCGANの論文では128のバッチサイズを使用しています。
- image_size - トレーニングに使用される画像の空間サイズ。この実装のデフォルトは64x64です。別のサイズが必要な場合は、DとGの構造を変更する必要があります。詳細はこちらをご覧ください。
- nc - 入力画像のカラーチャンネルの数。カラー画像の場合は 3 です。
- nz - 潜在ベクトルの長さ。
- ngf - 生成器を介して伝播される特徴マップの深さに関連します。
- ndf - 識別器を介して伝播される特徴マップの深さを設定します。
- num_epochs - 実行するトレーニングエポックの数。エポック数を大きくすると、より良い結果につながりますが、時間がかかります。
- lr - トレーニングの学習率。 DCGANの論文で説明されているように、この数値は0.0002である必要があります。
- beta1 - Adamオプティマイザーのbeta1ハイパーパラメーター。論文で説明されているように、この数値は0.5である必要があります。
- ngpu - 使用可能なGPUの数。0の場合、CPUモードで実行されます。この数が0より大きい場合、その数のGPUで実行されます。
# データセットのルートディレクトリ
dataroot = "data/celeba"
# データローダーのワーカー数
workers = 2
# トレーニングのバッチサイズ
batch_size = 128
# トレーニング画像の空間サイズ。
# すべての画像はトランスフォーマーを使用してこのサイズに変更されます。
image_size = 64
# トレーニング画像のチャネル数。カラー画像の場合は「3」
nc = 3
# 潜在ベクトル z のサイズ(つまり、ジェネレータ入力のサイズ)
nz = 100
# 生成器の feature map のサイズ
ngf = 64
# 識別器の feature map のサイズ
ndf = 64
# エポック数
num_epochs = 5
# 学習率
lr = 0.0002
# Adam オプティマイザのBeta1ハイパーパラメータ
beta1 = 0.5
# 使用可能なGPUの数。0の場合、CPUモードで実行されます
ngpu = 1
Data
このチュートリアルでは、リンク先のサイト またはGoogleドライブからダウンロードできる Celeb-A Faces dataset を使用します。データセットは、img_align_celeba.zipという名前のファイルとしてダウンロードできます。ダウンロードしたら、celebaという名前のディレクトリを作成し、zipファイルをそのディレクトリに解凍します。次に、このノートブックの dataroot にcelebaディレクトリを設定します。
ディレクトリ構造は次のようになります。
/path/to/celeba
-> img_align_celeba
-> 188242.jpg
-> 173822.jpg
-> 284702.jpg
-> 537394.jpg
...
ここでは、以下の処理を行います。
- データセットを作成する
- データローダーを作成する
- 実行するデバイスを設定する
- 最後にトレーニングデータの一部を視覚化する。
データセットのルートフォルダにはサブディレクトリが必要です。
# 画像フォルダデータセットは、以下で設定した方法で使用できます。
# データセットを作成する
dataset = dset.ImageFolder(root=dataroot,
transform=transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]))
# データローダーを作成する
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
shuffle=True, num_workers=workers)
# 実行するデバイスを決定する
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
# トレーニング画像をプロットする
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
Implementation(実装)
データセットを準備したら、実装を行います。
重みの初期化から始めて、生成器、識別器、損失関数、トレーニングについて詳しく説明します。
Weight Initialization(重みの初期化)
DCGANの論文では、モデルの重みを平均 0、標準偏差 0.02 の正規分布からランダムに初期化することが明示されています。
weights_init関数は、初期化されたモデルを入力として受け取り、すべての畳み込み層、および Batch Normalization 層を再初期化します。
この関数は、初期化直後のモデルに適用されます。
# G(生成器)とD(識別器)の重みの初期化
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
Generator(生成器)
ジェネレータ $G$は、潜在空間ベクトル($z$)をデータ空間にマッピングするように設計されています。データは画像のため、$z$ をデータ空間に変換すると、最終的にトレーニング画像と同じサイズ(3x64x64)のRGB画像が作成されます。
実際には、BatchNorm2d レイヤーおよび ReLU アクティベーションとペアになっている、2次元転置畳み込み層(ConvTranspose2d)によって実現されます。
ジェネレータの出力は、tanh関数で供給され、$[-1,1]$の入力データ範囲に戻ります。
DCGANの論文のジェネレーター画像を以下に示します。
入力セクションで設定した入力(nz、ngf、およびnc)がコードのジェネレーターアーキテクチャにどのように影響するかに注意してください。 nzは$z$入力ベクトルの長さ、ngfはジェネレータを介して伝播される特徴マップのサイズに関連し、ncは出力画像のチャネル数です(RGB画像の場合は3に設定されます)。以下はジェネレーターのコードです。
# Generator (生成器)
class Generator(nn.Module):
def __init__(self, ngpu):
super(Generator, self).__init__()
self.ngpu = ngpu
self.main = nn.Sequential(
# 入力は Z で、畳み込み層に渡されます
nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(ngf * 8),
nn.ReLU(True),
# サイズ (ngf*8) x 4 x 4
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True),
# サイズ (ngf*4) x 8 x 8
nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True),
# サイズ (ngf*2) x 16 x 16
nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(True),
# サイズ (ngf) x 32 x 32
nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
nn.Tanh()
# サイズ (nc) x 64 x 64
)
def forward(self, input):
return self.main(input)
ジェネレータ(生成器)をインスタンス化し、weights_init関数を適用します。
出力されたモデルをチェックし、ジェネレータがどのように構成されているか確認してください。
# ジェネレーターを作成します
netG = Generator(ngpu).to(device)
# 必要に応じてGPUを使用します
if (device.type == 'cuda') and (ngpu > 1):
netG = nn.DataParallel(netG, list(range(ngpu)))
# weights_init関数を適用して、すべての重みを平均「0」、標準偏差「0.02」でランダムに初期化します。
netG.apply(weights_init)
# モデルを出力します
print(netG)
Generator(
(main): Sequential(
(0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace=True)
(9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(11): ReLU(inplace=True)
(12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(13): Tanh()
)
)
Discriminator(識別器)
前述のように、識別器 $D$は、画像を入力として受け取り、入力画像が(偽物ではなく)本物であるという確率を出力する二値分類ネットワークです。
$D$ は3x64x64の入力画像を取得し、それを一連のConv2d、BatchNorm2d、およびLeakyReLU 層で処理し、Sigmoid関数を介して確率を出力します。 このアーキテクチャは、必要に応じて、より多くのレイヤーに拡張できますが、Strided Convolution、BatchNorm、および LeakyReLU を使用することが重要です。
DCGANの論文では、ダウンサンプリングにプーリングではなく、Strided Convolution を使用することを勧めています。また、batch norm と leaky relu も、$G$ と $D$ の両方の学習プロセスと勾配の計算に重要です。
以下は識別器のコードです。
class Discriminator(nn.Module):
def __init__(self, ngpu):
super(Discriminator, self).__init__()
self.ngpu = ngpu
self.main = nn.Sequential(
# 入力は (nc) x 64 x 64
nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# サイズ (ndf) x 32 x 32
nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.2, inplace=True),
# サイズ (ndf*2) x 16 x 16
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, inplace=True),
# サイズ (ndf*4) x 8 x 8
nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 8),
nn.LeakyReLU(0.2, inplace=True),
# サイズ (ndf*8) x 4 x 4
nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)
def forward(self, input):
return self.main(input)
これで、生成器と同様に、識別器を作成し、weights_init関数を適用して、モデルの構造を出力できます。
# Create the Discriminator
# 識別器を作成します
netD = Discriminator(ngpu).to(device)
# Handle multi-gpu if desired
# 必要に応じてGPUを使用します
if (device.type == 'cuda') and (ngpu > 1):
netD = nn.DataParallel(netD, list(range(ngpu)))
# Apply the weights_init function to randomly initialize all weights
# to mean=0, stdev=0.2.
# weights_init関数を適用して、すべての重みを平均「0」、標準偏差「0.02」でランダムに初期化します。
netD.apply(weights_init)
# Print the model
# モデルを出力します
print(netD)
Discriminator(
(main): Sequential(
(0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2, inplace=True)
(5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2, inplace=True)
(8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(10): LeakyReLU(negative_slope=0.2, inplace=True)
(11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
(12): Sigmoid()
)
)
Loss Functions and Optimizers
損失関数とオプティマイザで、生成器($G$)、識別器($D$)の学習方法を指定できます。
損失関数には PyTorch で次のように定義されているバイナリクロスエントロピー損失関数BCELossを使用します。
$\qquad \ell(x, y) = L = {l_1,\dots,l_N}^\top, \quad l_n = - \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right]$
この損失関数の対数成分 $log(D(x))$ と $log(1-D(G(z)))$ がどのような計算になるか注意してください。損失関数に渡す $y$ の値は呼び出し元で指定できます。$y$ (正解ラベル)に何を指定するかで計算するコンポーネントを選択できることが重要です。
($y_n$ に正解ラベル 1 を指定すると誤差関数の第1項目 $\log x_n$ が残り、$y_n$ に不正解ラベル 0 を指定すると第2項目 $\log (1 - x_n)$ が残ります。)
次に、実際のラベルを1、偽のラベルを0と定義します。これは、$D$ と $G$ の損失を計算するときに使用されます。
最後に、$D$ と$G$ の2つのオプティマイザを設定します。 DCGANの論文で指定されているように、どちらも学習率 0.0002 および Beta1 = 0.5 の Adam オプティマイザです。
ジェネレータの学習の進行状況を追跡するために、標準正規分布からランダムに抽出された潜在ベクトルの固定値 fixed_noise を生成します。トレーニングループでは、このfixed_noiseを定期的に $G$ に入力し、ループ中に固定値 fixed_noise から画像が形成されるのを確認します。
# BCELoss関数を初期化します
criterion = nn.BCELoss()
# ジェネレータの進行を視覚化するために使用する潜在ベクトルを作成します
fixed_noise = torch.randn(64, nz, 1, 1, device=device)
# トレーニング中に本物のラベルと偽のラベルのルールを設定します
real_label = 1.
fake_label = 0.
# G と D に Adam オプティマイザを設定する
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
Training
GANフレームワークのすべてを定義したので、トレーニングを実行できます。
ただし、パラメータの設定が正しくないとうまく学習できず、何が悪かったのかもほとんど分かりません。
ここでは、ganhacks に示されているいくつかのベストプラクティスを順守しながら、Goodfellowの論文のアルゴリズム1に従い、トレーニングします。
つまり、「実際の画像と偽の画像に異なるミニバッチを作成」し、「$logD(G(z))$ を最大化する」ように $G$ の目的関数を調整します。
トレーニングは2つの主要な部分に分かれています。パート1は識別器のパラメータを更新し、パート2は生成器を更新します。
Part 1 - Train the Discriminator
識別器をトレーニングする目的は、入力画像を本物か偽物を正しく分類する確率を最大化することです。
グッドフェロー氏の論文では、「確率的勾配を上昇させることによって識別器を更新する」ことが記述されています。実際には、$log(D(x)) + log(1-D(G(z)))$ を最大化する必要があります。
ganhacksとは異なるやり方ですが、以下の2つのステップで実行します。
1つ目のステップは、トレーニングデータセットから実在の画像のバッチを作成し、
$D$ を順伝播し、損失($log(D(x))$)を計算してから逆伝播の勾配を計算します。
2つ目のステップは、生成器を使用して偽の画像のバッチを作成し、$D$ を順伝播、損失($log(1-D(G(z)))$)を計算してから逆伝播の勾配を計算します。
2つのステップの両方で蓄積された勾配を使用して、識別器のオプティマイザの step を実行します。
Part 2 - Train the Generator
先の論文で述べているように、より良い偽物を生成するために、$log(1-D(G(z)))$ を最小化することによって生成器をトレーニングしたいと思います。
しかしながら、トレーニングの初期段階でうまく学習されないことがグッドフェロー氏によって示されています。
代わりに $log(D(G(z)))$ を最大化するように修正します。
以下のコードでは、パート1の生成器の出力を識別器で分類し、実際のラベルを正解ラベルとして $G$ の損失を計算し、逆伝播で $G$ の勾配を計算、最後にオプティマイザで$G$のパラメーターを更新することでこれを実現します。
損失関数の正解ラベルとして実際のラベルを使用することは直感に反するように思われるかもしれませんが、これにより、BCELossの $log(x)$ 部分($log(1-x)$部分ではなく)を使用できるようになります。
最後に、いくつかの統計レポートを作成しつつ、各エポックの終わりに fixed_noise をもとに生成器で画像を生成し、Gのトレーニングの進行状況を確認します。
報告されるトレーニング統計は次のとおりです。
- Loss_D - 識別器の損失。実在の画像を識別した時の損失 $log(D(x))$ と偽の画像を識別した時の損失 $log(1-D(G(z)))$ の合計値($log(D(x)) + log(1-D(G(z)))$)。(※ チュートリアル原文の式が誤っている気がしますがどうでしょうか。)
- Loss_G - $log(D(G(z)))$ として計算される生成器の損失
- D(x) - 実際の画像の識別器の(バッチ全体の)平均。これは1近くから始まり、Gが良くなると理論的には0.5に収束するはずです。
- D(G(z)) - 偽の画像の識別器の平均。最初の番号は$D$が更新される前であり、2番目の番号は$D$が更新された後です。これらの数値は0付近から始まり、$G$が良くなるにつれて0.5に収束するはずです。
注意 実行するエポックの数やデータセットからデータを削除したかどうかによっては、この手順に時間がかかる場合があります。
※Colaboratory の環境では入力データ(画像)をアップロードしきれず、実行できませんでした。自分のPC環境(古いPCでGPUはGTX 1050Ti です)だと、画像約20万枚、エポック数 5 で1時間ちょっとかかりました。
# トレーニングループ
# 進捗状況を追跡するためのリスト
img_list = []
G_losses = []
D_losses = []
iters = 0
print("Starting Training Loop...")
# エポックごとのループ
for epoch in range(num_epochs):
# データローダーのバッチごとのループ
for i, data in enumerate(dataloader, 0):
############################
# (1) Dネットワークの更新:log(D(x)) + log(1 - D(G(z))) を最大化します
###########################
## 実在の画像でトレーニングします
netD.zero_grad()
# バッチのフォーマット
real_cpu = data[0].to(device)
b_size = real_cpu.size(0)
label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
# 実在の写真で D の順伝播させます
output = netD(real_cpu).view(-1)
# 損失を計算します
errD_real = criterion(output, label)
# 逆伝播でDの勾配を計算します
errD_real.backward()
D_x = output.mean().item()
## 偽の画像でトレーニングします
# 潜在ベクトルのバッチを生成します
noise = torch.randn(b_size, nz, 1, 1, device=device)
# Gで偽の画像を生成します
fake = netG(noise)
label.fill_(fake_label)
# 生成した偽画像をDで分類します
output = netD(fake.detach()).view(-1)
# Dの損失を計算します
errD_fake = criterion(output, label)
# 勾配を計算します
errD_fake.backward()
D_G_z1 = output.mean().item()
# 実在の画像の勾配と偽画像の勾配を加算します
errD = errD_real + errD_fake
# Dを更新します
optimizerD.step()
############################
# (2) Gネットワークの更新:log(D(G(z))) を最大化します
###########################
netG.zero_grad()
label.fill_(real_label) # 偽のラベルは生成器の損失にとって本物です
# パラメータ更新後のDを利用して、偽画像を順伝播させます
output = netD(fake).view(-1)
# この出力に基づいてGの損失を計算します
errG = criterion(output, label)
# Gの勾配を計算します
errG.backward()
D_G_z2 = output.mean().item()
# Gを更新します
optimizerG.step()
# トレーニング統計を出力します
if i % 50 == 0:
print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
% (epoch, num_epochs, i, len(dataloader),
errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
# 後でプロットするために損失を保存します
G_losses.append(errG.item())
D_losses.append(errD.item())
# fixed_noiseによる G の出力を保存し、生成器の精度を確認します
if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
with torch.no_grad():
fake = netG(fixed_noise).detach().cpu()
img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
iters += 1
Starting Training Loop...
[0/5][0/1583] Loss_D: 1.7825 Loss_G: 5.0967 D(x): 0.5566 D(G(z)): 0.5963 / 0.0094
[0/5][50/1583] Loss_D: 1.2308 Loss_G: 21.3185 D(x): 0.9840 D(G(z)): 0.5805 / 0.0000
[0/5][100/1583] Loss_D: 0.2915 Loss_G: 6.6904 D(x): 0.9769 D(G(z)): 0.1790 / 0.0028
[0/5][150/1583] Loss_D: 0.3132 Loss_G: 5.4999 D(x): 0.8093 D(G(z)): 0.0210 / 0.0094
[0/5][200/1583] Loss_D: 1.8818 Loss_G: 10.8477 D(x): 0.4127 D(G(z)): 0.0022 / 0.0021
[0/5][250/1583] Loss_D: 0.3169 Loss_G: 3.5397 D(x): 0.8246 D(G(z)): 0.0601 / 0.0547
[0/5][300/1583] Loss_D: 0.3551 Loss_G: 3.1239 D(x): 0.9287 D(G(z)): 0.2033 / 0.0873
[0/5][350/1583] Loss_D: 0.3831 Loss_G: 3.4443 D(x): 0.8077 D(G(z)): 0.1026 / 0.0508
[0/5][400/1583] Loss_D: 0.4064 Loss_G: 4.7648 D(x): 0.8264 D(G(z)): 0.1073 / 0.0158
[0/5][450/1583] Loss_D: 0.5639 Loss_G: 6.6930 D(x): 0.9511 D(G(z)): 0.3329 / 0.0031
[0/5][500/1583] Loss_D: 1.3354 Loss_G: 4.8289 D(x): 0.9831 D(G(z)): 0.6628 / 0.0210
[0/5][550/1583] Loss_D: 0.4430 Loss_G: 5.3935 D(x): 0.8112 D(G(z)): 0.1430 / 0.0082
[0/5][600/1583] Loss_D: 0.4703 Loss_G: 3.0324 D(x): 0.7795 D(G(z)): 0.1257 / 0.0754
[0/5][650/1583] Loss_D: 1.7328 Loss_G: 3.3086 D(x): 0.3306 D(G(z)): 0.0012 / 0.1422
[0/5][700/1583] Loss_D: 0.3655 Loss_G: 5.8728 D(x): 0.9177 D(G(z)): 0.2211 / 0.0053
[0/5][750/1583] Loss_D: 0.2753 Loss_G: 5.8153 D(x): 0.9107 D(G(z)): 0.1408 / 0.0061
[0/5][800/1583] Loss_D: 0.6434 Loss_G: 3.9464 D(x): 0.7840 D(G(z)): 0.2572 / 0.0382
[0/5][850/1583] Loss_D: 0.4208 Loss_G: 3.9550 D(x): 0.8904 D(G(z)): 0.1904 / 0.0348
[0/5][900/1583] Loss_D: 0.1478 Loss_G: 4.9144 D(x): 0.9053 D(G(z)): 0.0304 / 0.0184
[0/5][950/1583] Loss_D: 0.4626 Loss_G: 3.9822 D(x): 0.7520 D(G(z)): 0.0995 / 0.0307
[0/5][1000/1583] Loss_D: 1.3251 Loss_G: 6.6428 D(x): 0.8265 D(G(z)): 0.5140 / 0.0039
[0/5][1050/1583] Loss_D: 0.2756 Loss_G: 4.9520 D(x): 0.9536 D(G(z)): 0.1760 / 0.0161
[0/5][1100/1583] Loss_D: 0.4271 Loss_G: 4.3152 D(x): 0.8585 D(G(z)): 0.1801 / 0.0276
[0/5][1150/1583] Loss_D: 0.4866 Loss_G: 5.8982 D(x): 0.9237 D(G(z)): 0.2801 / 0.0061
[0/5][1200/1583] Loss_D: 0.5454 Loss_G: 3.4798 D(x): 0.7154 D(G(z)): 0.0932 / 0.0484
[0/5][1250/1583] Loss_D: 0.7009 Loss_G: 3.6421 D(x): 0.7268 D(G(z)): 0.2214 / 0.0448
[0/5][1300/1583] Loss_D: 0.9729 Loss_G: 6.5268 D(x): 0.9301 D(G(z)): 0.4758 / 0.0049
[0/5][1350/1583] Loss_D: 0.2977 Loss_G: 4.4779 D(x): 0.8450 D(G(z)): 0.0783 / 0.0217
[0/5][1400/1583] Loss_D: 0.4291 Loss_G: 4.7027 D(x): 0.7973 D(G(z)): 0.1081 / 0.0232
[0/5][1450/1583] Loss_D: 0.5062 Loss_G: 1.7627 D(x): 0.7334 D(G(z)): 0.0914 / 0.2356
[0/5][1500/1583] Loss_D: 0.3683 Loss_G: 2.8416 D(x): 0.7972 D(G(z)): 0.0877 / 0.0884
[0/5][1550/1583] Loss_D: 1.2416 Loss_G: 6.1441 D(x): 0.9658 D(G(z)): 0.6089 / 0.0039
[1/5][0/1583] Loss_D: 0.4863 Loss_G: 4.1000 D(x): 0.8147 D(G(z)): 0.1864 / 0.0311
[1/5][50/1583] Loss_D: 0.5487 Loss_G: 4.3008 D(x): 0.8068 D(G(z)): 0.2321 / 0.0231
[1/5][100/1583] Loss_D: 0.3973 Loss_G: 3.1769 D(x): 0.8267 D(G(z)): 0.1488 / 0.0607
[1/5][150/1583] Loss_D: 0.4199 Loss_G: 3.7225 D(x): 0.8362 D(G(z)): 0.1711 / 0.0388
[1/5][200/1583] Loss_D: 0.3704 Loss_G: 3.1646 D(x): 0.7950 D(G(z)): 0.0742 / 0.0690
[1/5][250/1583] Loss_D: 0.4835 Loss_G: 2.6290 D(x): 0.7500 D(G(z)): 0.0945 / 0.1090
[1/5][300/1583] Loss_D: 1.7134 Loss_G: 0.8583 D(x): 0.3044 D(G(z)): 0.0100 / 0.5134
[1/5][350/1583] Loss_D: 1.2287 Loss_G: 1.1995 D(x): 0.4230 D(G(z)): 0.0291 / 0.4012
[1/5][400/1583] Loss_D: 0.3184 Loss_G: 3.1359 D(x): 0.8466 D(G(z)): 0.1128 / 0.0672
[1/5][450/1583] Loss_D: 0.4920 Loss_G: 2.9625 D(x): 0.7946 D(G(z)): 0.1426 / 0.0824
[1/5][500/1583] Loss_D: 0.4525 Loss_G: 3.3157 D(x): 0.8559 D(G(z)): 0.2031 / 0.0599
[1/5][550/1583] Loss_D: 0.3221 Loss_G: 3.4189 D(x): 0.8521 D(G(z)): 0.1203 / 0.0591
[1/5][600/1583] Loss_D: 0.5004 Loss_G: 3.6299 D(x): 0.8184 D(G(z)): 0.2133 / 0.0486
[1/5][650/1583] Loss_D: 0.4997 Loss_G: 4.4381 D(x): 0.8822 D(G(z)): 0.2777 / 0.0170
[1/5][700/1583] Loss_D: 0.3672 Loss_G: 2.6801 D(x): 0.8161 D(G(z)): 0.1072 / 0.1065
[1/5][750/1583] Loss_D: 0.4129 Loss_G: 2.7585 D(x): 0.8070 D(G(z)): 0.1383 / 0.0915
[1/5][800/1583] Loss_D: 1.3243 Loss_G: 6.4724 D(x): 0.9383 D(G(z)): 0.6237 / 0.0046
[1/5][850/1583] Loss_D: 0.4103 Loss_G: 3.9704 D(x): 0.8653 D(G(z)): 0.2067 / 0.0287
[1/5][900/1583] Loss_D: 0.5130 Loss_G: 2.8125 D(x): 0.7487 D(G(z)): 0.1296 / 0.0867
[1/5][950/1583] Loss_D: 0.3896 Loss_G: 3.7857 D(x): 0.8697 D(G(z)): 0.1867 / 0.0340
[1/5][1000/1583] Loss_D: 0.3782 Loss_G: 4.0259 D(x): 0.8248 D(G(z)): 0.1319 / 0.0324
[1/5][1050/1583] Loss_D: 0.3561 Loss_G: 3.6737 D(x): 0.8014 D(G(z)): 0.0939 / 0.0508
[1/5][1100/1583] Loss_D: 1.1566 Loss_G: 6.8887 D(x): 0.9395 D(G(z)): 0.5788 / 0.0030
[1/5][1150/1583] Loss_D: 0.4082 Loss_G: 2.7026 D(x): 0.8282 D(G(z)): 0.1693 / 0.0954
[1/5][1200/1583] Loss_D: 0.4618 Loss_G: 2.0251 D(x): 0.7042 D(G(z)): 0.0517 / 0.1746
[1/5][1250/1583] Loss_D: 0.4982 Loss_G: 3.7400 D(x): 0.8717 D(G(z)): 0.2661 / 0.0350
[1/5][1300/1583] Loss_D: 0.8765 Loss_G: 2.0860 D(x): 0.5928 D(G(z)): 0.1493 / 0.1935
[1/5][1350/1583] Loss_D: 0.4507 Loss_G: 2.7327 D(x): 0.8017 D(G(z)): 0.1732 / 0.0891
[1/5][1400/1583] Loss_D: 0.6554 Loss_G: 4.4541 D(x): 0.9331 D(G(z)): 0.3904 / 0.0190
[1/5][1450/1583] Loss_D: 0.3707 Loss_G: 3.0559 D(x): 0.8092 D(G(z)): 0.1124 / 0.0752
[1/5][1500/1583] Loss_D: 0.3978 Loss_G: 2.9060 D(x): 0.8130 D(G(z)): 0.1448 / 0.0722
[1/5][1550/1583] Loss_D: 0.6422 Loss_G: 3.3732 D(x): 0.7391 D(G(z)): 0.2337 / 0.0562
[2/5][0/1583] Loss_D: 0.5571 Loss_G: 4.3724 D(x): 0.8466 D(G(z)): 0.2813 / 0.0222
[2/5][50/1583] Loss_D: 0.9171 Loss_G: 3.3607 D(x): 0.5004 D(G(z)): 0.0169 / 0.0584
[2/5][100/1583] Loss_D: 0.5236 Loss_G: 2.8583 D(x): 0.8345 D(G(z)): 0.2470 / 0.0811
[2/5][150/1583] Loss_D: 0.6476 Loss_G: 2.4148 D(x): 0.6520 D(G(z)): 0.0886 / 0.1282
[2/5][200/1583] Loss_D: 0.4759 Loss_G: 2.1868 D(x): 0.7868 D(G(z)): 0.1683 / 0.1432
[2/5][250/1583] Loss_D: 0.4453 Loss_G: 3.5857 D(x): 0.8142 D(G(z)): 0.1816 / 0.0417
[2/5][300/1583] Loss_D: 0.6185 Loss_G: 1.9197 D(x): 0.6361 D(G(z)): 0.0578 / 0.1909
[2/5][350/1583] Loss_D: 0.5656 Loss_G: 1.8746 D(x): 0.6823 D(G(z)): 0.0907 / 0.2027
[2/5][400/1583] Loss_D: 0.5084 Loss_G: 3.8005 D(x): 0.8927 D(G(z)): 0.2975 / 0.0310
[2/5][450/1583] Loss_D: 0.6327 Loss_G: 3.6768 D(x): 0.8792 D(G(z)): 0.3508 / 0.0370
[2/5][500/1583] Loss_D: 1.0236 Loss_G: 0.9627 D(x): 0.5368 D(G(z)): 0.2148 / 0.4402
[2/5][550/1583] Loss_D: 0.5240 Loss_G: 1.8640 D(x): 0.6880 D(G(z)): 0.0900 / 0.1949
[2/5][600/1583] Loss_D: 1.4416 Loss_G: 0.7844 D(x): 0.3551 D(G(z)): 0.0359 / 0.5423
[2/5][650/1583] Loss_D: 0.7049 Loss_G: 2.6899 D(x): 0.7614 D(G(z)): 0.3058 / 0.0898
[2/5][700/1583] Loss_D: 0.6580 Loss_G: 2.1950 D(x): 0.6655 D(G(z)): 0.1447 / 0.1515
[2/5][750/1583] Loss_D: 0.4866 Loss_G: 2.9747 D(x): 0.8271 D(G(z)): 0.2238 / 0.0660
[2/5][800/1583] Loss_D: 0.8614 Loss_G: 3.2034 D(x): 0.8388 D(G(z)): 0.4398 / 0.0602
[2/5][850/1583] Loss_D: 0.5973 Loss_G: 1.9877 D(x): 0.7525 D(G(z)): 0.2234 / 0.1688
[2/5][900/1583] Loss_D: 0.6860 Loss_G: 1.3544 D(x): 0.6354 D(G(z)): 0.1454 / 0.3062
[2/5][950/1583] Loss_D: 0.5878 Loss_G: 2.2449 D(x): 0.7269 D(G(z)): 0.1780 / 0.1410
[2/5][1000/1583] Loss_D: 1.0009 Loss_G: 1.7307 D(x): 0.4573 D(G(z)): 0.0505 / 0.2349
[2/5][1050/1583] Loss_D: 0.7103 Loss_G: 1.6821 D(x): 0.6081 D(G(z)): 0.1171 / 0.2361
[2/5][1100/1583] Loss_D: 1.2062 Loss_G: 1.9829 D(x): 0.5149 D(G(z)): 0.2336 / 0.1943
[2/5][1150/1583] Loss_D: 0.5601 Loss_G: 3.1371 D(x): 0.8733 D(G(z)): 0.3131 / 0.0597
[2/5][1200/1583] Loss_D: 0.4691 Loss_G: 2.6874 D(x): 0.8428 D(G(z)): 0.2337 / 0.0879
[2/5][1250/1583] Loss_D: 0.8933 Loss_G: 3.6654 D(x): 0.8084 D(G(z)): 0.4449 / 0.0375
[2/5][1300/1583] Loss_D: 1.1260 Loss_G: 1.2977 D(x): 0.4179 D(G(z)): 0.0427 / 0.3397
[2/5][1350/1583] Loss_D: 0.5914 Loss_G: 2.9971 D(x): 0.8794 D(G(z)): 0.3353 / 0.0661
[2/5][1400/1583] Loss_D: 0.5957 Loss_G: 1.9661 D(x): 0.6747 D(G(z)): 0.1309 / 0.1766
[2/5][1450/1583] Loss_D: 0.6327 Loss_G: 3.0287 D(x): 0.8845 D(G(z)): 0.3597 / 0.0633
[2/5][1500/1583] Loss_D: 0.6132 Loss_G: 2.0735 D(x): 0.6484 D(G(z)): 0.1048 / 0.1586
[2/5][1550/1583] Loss_D: 0.6960 Loss_G: 3.9634 D(x): 0.8539 D(G(z)): 0.3680 / 0.0252
[3/5][0/1583] Loss_D: 1.1915 Loss_G: 2.0190 D(x): 0.4046 D(G(z)): 0.0497 / 0.1893
[3/5][50/1583] Loss_D: 0.5441 Loss_G: 2.6981 D(x): 0.8116 D(G(z)): 0.2500 / 0.0867
[3/5][100/1583] Loss_D: 0.8563 Loss_G: 3.1579 D(x): 0.7234 D(G(z)): 0.3582 / 0.0524
[3/5][150/1583] Loss_D: 0.5450 Loss_G: 2.2794 D(x): 0.7424 D(G(z)): 0.1788 / 0.1285
[3/5][200/1583] Loss_D: 1.5069 Loss_G: 5.1980 D(x): 0.9165 D(G(z)): 0.6954 / 0.0112
[3/5][250/1583] Loss_D: 0.9336 Loss_G: 4.0916 D(x): 0.9397 D(G(z)): 0.5212 / 0.0248
[3/5][300/1583] Loss_D: 0.4925 Loss_G: 1.6877 D(x): 0.7714 D(G(z)): 0.1743 / 0.2235
[3/5][350/1583] Loss_D: 0.5780 Loss_G: 2.4934 D(x): 0.8048 D(G(z)): 0.2654 / 0.1000
[3/5][400/1583] Loss_D: 0.6465 Loss_G: 2.2793 D(x): 0.7224 D(G(z)): 0.2318 / 0.1355
[3/5][450/1583] Loss_D: 0.7550 Loss_G: 1.4872 D(x): 0.5877 D(G(z)): 0.1296 / 0.2714
[3/5][500/1583] Loss_D: 0.5376 Loss_G: 3.1298 D(x): 0.8589 D(G(z)): 0.2911 / 0.0580
[3/5][550/1583] Loss_D: 0.7182 Loss_G: 1.3164 D(x): 0.6177 D(G(z)): 0.1423 / 0.3221
[3/5][600/1583] Loss_D: 0.8334 Loss_G: 3.9036 D(x): 0.8838 D(G(z)): 0.4598 / 0.0299
[3/5][650/1583] Loss_D: 0.5437 Loss_G: 2.0172 D(x): 0.8044 D(G(z)): 0.2493 / 0.1646
[3/5][700/1583] Loss_D: 0.5019 Loss_G: 2.5320 D(x): 0.8308 D(G(z)): 0.2389 / 0.1014
[3/5][750/1583] Loss_D: 0.6720 Loss_G: 1.9764 D(x): 0.7117 D(G(z)): 0.2307 / 0.1763
[3/5][800/1583] Loss_D: 0.5250 Loss_G: 2.3763 D(x): 0.7334 D(G(z)): 0.1590 / 0.1205
[3/5][850/1583] Loss_D: 0.9347 Loss_G: 3.8405 D(x): 0.8708 D(G(z)): 0.5005 / 0.0300
[3/5][900/1583] Loss_D: 0.7929 Loss_G: 3.0115 D(x): 0.7821 D(G(z)): 0.3699 / 0.0696
[3/5][950/1583] Loss_D: 0.5819 Loss_G: 1.6618 D(x): 0.6905 D(G(z)): 0.1401 / 0.2291
[3/5][1000/1583] Loss_D: 0.6286 Loss_G: 2.7553 D(x): 0.8014 D(G(z)): 0.3004 / 0.0802
[3/5][1050/1583] Loss_D: 0.7823 Loss_G: 3.1819 D(x): 0.7950 D(G(z)): 0.3752 / 0.0611
[3/5][1100/1583] Loss_D: 0.8769 Loss_G: 3.5589 D(x): 0.9221 D(G(z)): 0.4930 / 0.0420
[3/5][1150/1583] Loss_D: 0.8988 Loss_G: 5.1972 D(x): 0.9245 D(G(z)): 0.5103 / 0.0088
[3/5][1200/1583] Loss_D: 0.9517 Loss_G: 4.5984 D(x): 0.8263 D(G(z)): 0.4784 / 0.0145
[3/5][1250/1583] Loss_D: 0.8584 Loss_G: 3.2372 D(x): 0.8411 D(G(z)): 0.4431 / 0.0538
[3/5][1300/1583] Loss_D: 0.9749 Loss_G: 4.0050 D(x): 0.9209 D(G(z)): 0.5348 / 0.0284
[3/5][1350/1583] Loss_D: 1.0938 Loss_G: 0.3283 D(x): 0.4292 D(G(z)): 0.0768 / 0.7444
[3/5][1400/1583] Loss_D: 0.6101 Loss_G: 2.1225 D(x): 0.7091 D(G(z)): 0.1894 / 0.1491
[3/5][1450/1583] Loss_D: 0.5325 Loss_G: 2.2364 D(x): 0.6970 D(G(z)): 0.1181 / 0.1433
[3/5][1500/1583] Loss_D: 0.8403 Loss_G: 3.3129 D(x): 0.8605 D(G(z)): 0.4357 / 0.0549
[3/5][1550/1583] Loss_D: 0.5560 Loss_G: 2.1209 D(x): 0.7575 D(G(z)): 0.2112 / 0.1460
[4/5][0/1583] Loss_D: 1.2497 Loss_G: 0.6124 D(x): 0.3445 D(G(z)): 0.0367 / 0.5943
[4/5][50/1583] Loss_D: 1.0927 Loss_G: 0.9180 D(x): 0.4104 D(G(z)): 0.0462 / 0.4446
[4/5][100/1583] Loss_D: 0.9225 Loss_G: 0.4822 D(x): 0.4784 D(G(z)): 0.0647 / 0.6530
[4/5][150/1583] Loss_D: 0.5889 Loss_G: 2.3969 D(x): 0.7178 D(G(z)): 0.1811 / 0.1196
[4/5][200/1583] Loss_D: 0.5160 Loss_G: 2.1426 D(x): 0.6925 D(G(z)): 0.1009 / 0.1535
[4/5][250/1583] Loss_D: 0.6147 Loss_G: 1.8767 D(x): 0.6149 D(G(z)): 0.0715 / 0.1976
[4/5][300/1583] Loss_D: 0.7196 Loss_G: 1.8276 D(x): 0.6679 D(G(z)): 0.2232 / 0.1950
[4/5][350/1583] Loss_D: 0.6552 Loss_G: 3.0728 D(x): 0.8624 D(G(z)): 0.3590 / 0.0594
[4/5][400/1583] Loss_D: 0.4485 Loss_G: 2.0510 D(x): 0.7991 D(G(z)): 0.1797 / 0.1553
[4/5][450/1583] Loss_D: 0.5292 Loss_G: 2.2283 D(x): 0.7730 D(G(z)): 0.2061 / 0.1299
[4/5][500/1583] Loss_D: 0.9299 Loss_G: 2.7818 D(x): 0.7656 D(G(z)): 0.4220 / 0.0857
[4/5][550/1583] Loss_D: 0.4178 Loss_G: 3.2685 D(x): 0.8952 D(G(z)): 0.2436 / 0.0496
[4/5][600/1583] Loss_D: 0.8358 Loss_G: 1.6605 D(x): 0.5853 D(G(z)): 0.1576 / 0.2465
[4/5][650/1583] Loss_D: 2.2432 Loss_G: 2.5945 D(x): 0.9064 D(G(z)): 0.8167 / 0.1101
[4/5][700/1583] Loss_D: 0.5702 Loss_G: 2.5894 D(x): 0.7715 D(G(z)): 0.2206 / 0.0984
[4/5][750/1583] Loss_D: 1.3677 Loss_G: 1.1562 D(x): 0.3420 D(G(z)): 0.0918 / 0.3838
[4/5][800/1583] Loss_D: 0.5114 Loss_G: 2.1584 D(x): 0.7484 D(G(z)): 0.1617 / 0.1515
[4/5][850/1583] Loss_D: 1.0217 Loss_G: 0.9296 D(x): 0.4531 D(G(z)): 0.0982 / 0.4508
[4/5][900/1583] Loss_D: 0.9362 Loss_G: 2.2130 D(x): 0.7036 D(G(z)): 0.3690 / 0.1460
[4/5][950/1583] Loss_D: 0.5356 Loss_G: 1.9746 D(x): 0.7591 D(G(z)): 0.1908 / 0.1650
[4/5][1000/1583] Loss_D: 0.6062 Loss_G: 3.4775 D(x): 0.8733 D(G(z)): 0.3460 / 0.0390
[4/5][1050/1583] Loss_D: 0.7407 Loss_G: 1.4774 D(x): 0.5800 D(G(z)): 0.1078 / 0.2783
[4/5][1100/1583] Loss_D: 0.6590 Loss_G: 1.7840 D(x): 0.6090 D(G(z)): 0.0856 / 0.2158
[4/5][1150/1583] Loss_D: 1.1789 Loss_G: 4.3504 D(x): 0.8746 D(G(z)): 0.5924 / 0.0210
[4/5][1200/1583] Loss_D: 0.5922 Loss_G: 2.0222 D(x): 0.7280 D(G(z)): 0.1955 / 0.1713
[4/5][1250/1583] Loss_D: 0.6910 Loss_G: 1.3199 D(x): 0.6022 D(G(z)): 0.1046 / 0.3194
[4/5][1300/1583] Loss_D: 0.6726 Loss_G: 2.6413 D(x): 0.8492 D(G(z)): 0.3557 / 0.0972
[4/5][1350/1583] Loss_D: 0.7808 Loss_G: 3.7603 D(x): 0.8605 D(G(z)): 0.4220 / 0.0317
[4/5][1400/1583] Loss_D: 0.6173 Loss_G: 3.5769 D(x): 0.9113 D(G(z)): 0.3699 / 0.0392
[4/5][1450/1583] Loss_D: 0.4523 Loss_G: 2.5795 D(x): 0.7925 D(G(z)): 0.1731 / 0.0977
[4/5][1500/1583] Loss_D: 0.4865 Loss_G: 2.3129 D(x): 0.8465 D(G(z)): 0.2489 / 0.1228
[4/5][1550/1583] Loss_D: 1.0133 Loss_G: 4.2263 D(x): 0.8779 D(G(z)): 0.5241 / 0.0254
Results
最後に、結果を確認しましょう。ここでは、3つの異なる結果を見ていきます。
まず、トレーニング中に $D$ と $G$ の損失がどのように変化したかを確認します。
次に、各エポックのfixed_noiseによる $G$ の出力を視覚化します。
最後に、$G$ で生成した画像と、実際の画像を並べてみます。
損失とトレーニングの反復
以下は、D と G の損失とトレーニングの反復のプロットです。
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
Gの進行の視覚化
トレーニングの各エポックで、fixed_noiseで生成器の出力画像を保存しました。これで、Gのトレーニングの進行をアニメーションで視覚化できます。再生ボタンを押してアニメーションを開始します。(Qiitaではアニメーションを貼り付けられませんのでGIFでアップロードしています)
#%%capture
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
HTML(ani.to_jshtml())
実画像と偽画像
最後に、いくつかの実際の画像と偽の画像を並べて見てみましょう。
# データローダから実際の画像のバッチを取得します
real_batch = next(iter(dataloader))
# 実際の画像をプロットします
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))
# 最後のエポックからの偽の画像をプロットします
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()
Where to Go Next
チュートリアルは以上ですが、次にできることがいくつかあります。
- より長くトレーニングし、結果が良くなるか確認します。
- このモデルを変更して別のデータセットを取得し、場合によっては画像のサイズとモデルアーキテクチャを変更します。
- ここで他のGANプロジェクトをチェックしてください。
- 音楽を生成するGANを作成します。
終わりに
今回は、GAN(敵対的生成ネットワーク)を学びました。
次回は「Audio I/O and Pre-Processing with torchaudio」を進めてみたいと思います。
履歴
2021/2/3 初版公開