■初めに
この記事ではpytorchを用いたDCGANの実装について解説を行っています。
実装したものに関しては下記のGitHubレポジトリにアップしてありますので、用いたい場合はクローンしてお使いください。
https://github.com/YusukeOhnishi/DCGAN/tree/main
■GANとは
簡単にGANについて説明を行います。
GAN(敵対性ネットワーク)ではネットワークを2つ作成を行います。1つは生成器、もう1つは識別器と呼ばれるものです。基本的に欲しいものとしては生成器の方ですが、2つを同時に学習させる必要があります。
学習の流れとしては、まず生成器にてインプットにノイズを与えて画像を作成させるということを行います。このようにして作成した偽物の画像と本物の画像を識別器のインプットとして入力し、アウトプットとしては本物か偽物かを判定させます。またこの時生成器では識別器が本物と勘違いするような画像を生成するように学習を行わせます。
これらを同時に学習させることで、生成器ではより本物に近い画像を作成できるようになり、分類器ではより正確に本物と偽物を区別できるようになります。
これをより数学的に書くと下記のようになります。
GANにて生成される画像は確率分布$p_g(x|z)$に従い生成されます。本来のリアルデータの確率分布を$p_r(x)$とすると、初期状態ではこれら2つの分布は全く異なっていますが学習を進める毎に近づいていきます。ここで確率分布同士が近づいたということを、どう表現するのかということが疑問として残ります。これを測る指標としてはKL divergenceというものがあります。このKL divergenceは2つの確率分布関数間の距離尺度を与えるものであり、下記のような式で与えられます。
D_{\mathrm{KL}}(P||Q)=\int_{-\infty}^\infty p(x)\log\frac{p(x)}{q(x)}\mathrm{d}x
この$D_{\mathrm{KL}}(P||Q)$が小さいほど関数は近いということを表します。GANでよく距離尺度として用いられるものはこのKL divergenceを対称化したJS divergenceというものです。意味合いとしては同じです。JS divergenceは下記の式にて与えられます。
D_{\mathrm{JS}}(P||Q)=\frac{1}{2}D_{\mathrm{KL}}(P||\frac{1}{2}(P+Q))+\frac{1}{2}D_{\mathrm{KL}}(Q||\frac{1}{2}(P+Q))
GANの場合あればこの式を書き下すと下記のようになります。
D_{\mathrm{JS}}(p_r||p_g)=frac{1}{2}\left[D_{\mathrm{KL}}\left(p_r||\frac{p_r+p_g}{2}\right)+D_{\mathrm{KL}}\left(p_g||\frac{p_r+p_g}{2}\right)\right]\sim \left[\int p_r\log\left(\frac{p_r}{p_r+p_g}\right)\mathrm{d}x+\int p_g\log\left(\frac{p_g}{p_r+p_g}\right)\mathrm{d}x\right]
ここで$D(x)$を識別器が$x$を本物と分類する確率、$G(z)$をノイズ$z$を生成器にて変換した画像として、$\frac{p_r}{p_r+p_g}\simeq D(x),\frac{p_g}{p_r+p_g}\simeq 1-D(G(x))$と計算することができるため、上式はそれぞれ$D(x)$の確率分布$p_r$による期待値と、$1-D(G(x))$の確率分布$p_g$による期待値の和の形とみることができるため、結局下記のように書くことができます。
D_{\mathrm{JS}}=(\mathrm{E}[\log D(x)]+\mathrm{E}[\log(1- D(G(z)))])
つまりこの式は本物のデータ$x$を与えた場合に真と判定する確率の期待値と、偽のデータ$G(z)$を与えた場合に偽と判定する確率の期待値をそれぞれ足したような表式になっています。(厳密には$log$を取ったものの期待値なので、単純な期待値同士の足し算とは意味合いが異なるが)
よって$G(z)$においてじゃ偽物を本物と判定するように学習させたいので後ろの項を小さくするように学習を行います。つまり、$G(z)$は上記のJS divergenceを最小化する方向に学習します。
逆に識別器では本物は本物、偽物は偽物として判定を行いたいためJS divergenceを最大化する方向に学習を行います。
■DCGAN(DeepConvolutional GAN)とは
DCGANとはGANの中で最も基本的なモデルとなります。DCGANのDCは深層畳み込みを表し、モデル中で畳み込み演算を用いています。
▶生成器
DCGANにおける生成器は上記の図のような構造をしています。
入力となる橙色の図形はランダムなノイズとなっています。これはチャネル数が100の$1×1$の構造になっており、これをチャネル数$512$の$3\times 3$画像$\rightarrow$チャネル数$256$の$7\times 7$画像$\rightarrow$チャネル数$128$の$14\times 14$画像$\rightarrow$ チャネル数$1$の$28\times 28$と引き延ばしていき最終的にチャネル数$1$、つまりモノクロの$28\times28$のサイズの画像を取得します。
また各立方体に関して色ごとに下記のような処理を行っています。
・赤:Conv Transpose + Batch Norm + ReLU
・青:Conv Transpose + Tanh
ここで出てきたConvolutional Transpose(転置畳み込み)についてはこの記事内では解説は行いませんので、下記のページ等を参考にしてください。
https://cvml-expertguide.net/terms/dl/layers/convolution/transposed-convolution/
▶識別器
DCGANにおける識別器は上記の図のような構造をしています。
入力としてはチャネル数$1$の$28\times28$の画像を入力し、これを畳み込みチャネル数$128$の$14\times14$画像$\rightarrow$チャネル数$256$の$7\times7$画像$\rightarrow$チャネル数$512$の$3\times3$画像$\rightarrow$チャネル数$1$の$1\times1$画像と変換していき最終的に得られたものをシグモイド関数にて$0\sim1$の値に変換して結果を得ます。
各立方体の色によって下記のように処理をしています。
・青:Conv + LeakyReLU
・赤:Conv + Batch Norm + LeakyReLU
・黄:Conv
ここでLeakyReLUに関して簡単に説明します。LeakyReLUはReLUを少し変形した形となっていて、ReLUの場合はxが負の領域で傾き0としていたところを、LeakyReLUの場合はxが負の領域で一定の傾きを持たせるという点が異なります。詳細については下記を参照ください。
https://nisshingeppo.com/ai/leaky-relu-function/
これを用いる理由としては損失関数の計算において誤差を上手く逆伝播させるためです。GANの損失関数では$D(G(z))$のような形が出てきますが、逆伝播する際には$D()$の部分から先に誤差を伝えていくということを行います。この際に傾きを$0$としてしまうと、$D()$の中にある$G(z)$に誤差が伝わらないことになります。そのため、$D()$、つまり識別器側では一定量の誤差を残しておきたいという意図があり、LeakyReLUを持ちます。
▶損失関数
上記の説明では損失関数として下記のような形を生成器では最小化、識別器では最大化するということを記載しました。
\newcommand{\max}{\mathop{\rm max}\limits}
\newcommand{\min}{\mathop{\rm min}\limits}
\min_G\max_D \left(E_{x\sim p_r}[\log D(x)]+E_{z\sim p_z}[\log(1-D(G(z)))]\right)
しかしこの関数を用いた場合、計算が収束するまでに多くの時間がかかり、上手くいかないことが経験的に示されています。そのため、よく用いられる手法としては、生成器側の損失関数として別の関数を用意するというものがあります。これを具体的に書くと下記のようになり、学習の際はそれぞれを最小化するように勾配の計算を行います。
\mathrm{Loss D:}-E_{x\sim p_r}[\log D(x)]-E_{z\sim p_z}[\log(1-D(G(z)))]\\
\mathrm{Loss G:}-E_{z\sim p_z}[\log(D(G(z)))]\\
■実装
データセットのダウンロードを行う必要があるため、下記を実行してデータセットのダウンロードを行います。ここではカレントディレクトリの./data/フォルダにダウンロードを行うようにしています。また合わせて必要なライブラリをインポートします。
#ライブラリのインポート
import matplotlib.pyplot as plt
import numpy as np
import tqdm
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.utils as vutils
import torchvision.transforms as transforms
#data配下にMNISTのデータがダウンロードされる。
dataset=dset.MNIST("./data/",train=True,download=True)
▶ハイパーパラメータの定義
ハイパーパラメータの設定にはargparseというライブラリを用いて下記のように実装を行います。
import argparse
#parserを初期化
parser=argparse.ArgumentParser()
#parserに引数を設定していく
#エポック数の設定
parser.add_argument("--n_epoch",type=int,default=200)
#バッチサイズ
parser.add_argument("--batch_size",type=int,default=64)
#学習率
parser.add_argument("--lr",type=float,default=2e-4)
#生成器のチャネル数
parser.add_argument("--nch_g",type=int,default=128)
#識別器のチャネル数
parser.add_argument("--nch_d",type=int,default=128)
#ノイズの次元
parser.add_argument("--z_dim",type=int,default=100)
#Adam Optimizerのパラメータ
parser.add_argument("--beta1",type=float,default=0.5)
#optに作成したハイパーパラメータの組を渡す。(デフォルトで入れている値そのまま)
opt=parser.parse_args(args=[])
▶生成器の実装
生成器を作成していきます。その際図に沿って見ていくとわかりやすいので、上記に載せた図を再掲します。
各立方体の処理
・赤:Conv Transpose + Batch Norm + ReLU
・青:Conv Transpose + Tanh
class Generator(nn.Module):
def __init__(self,z_dim=100,ngf=128,nc=1):
super().__init__()
#図の立方体一つがconvtXとなる。
self.convt1=self.conv_trans_layers(z_dim,4*ngf,3,1,0)
self.convt2=self.conv_trans_layers(4*ngf,2*ngf,3,2,0)
self.convt3=self.conv_trans_layers(2*ngf,ngf,4,2,1)
#図の最後の立方体のみ構造が異なるので個別に作成
self.convt4=nn.Sequential(
nn.ConvTranspose2d(ngf,nc,4,2,1),
nn.Tanh()
)
#図の立方体のネットワークつまり、転置畳み込み + Batch Norm + Reluの構造を作成
@staticmethod
def conv_trans_layers(in_channels,out_channels,kernel_size,stride,padding):
net=nn.Sequential(
nn.ConvTranspose2d(in_channels,out_channels,kernel_size,stride,padding,bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
return net
#フォワード計算
def forward(self,x):
#1層目にインプットを与え、出力されたアウトプットを2層目のインプットにする。ということを4層目まで行う。
out=self.convt1(x)
out=self.convt2(out)
out=self.convt3(out)
out=self.convt4(out)
return out
▶識別器の実装
識別器を作成していきます。その際図に沿って見ていくとわかりやすいので、上記に載せた図を再掲します。
各立方体の処理
・青:Conv + LeakyReLU
・赤:Conv + Batch Norm + LeakyReLU
・黄:Conv
class Discriminator(nn.Module):
def __init__(self,nc=1,nd=128):
super().__init__()
#1層目は batch_normなし
self.conv1=self.conv_layers(nc,ndf,has_batch_norm=False)
#2,3層目は batch_normあり
self.conv2=self.conv_layers(ndf,2*ndf)
self.conv3=self.conv_layers(2*ndf,4*ndf,3,2,0)
#4層目はconvを行ったのちにsigmoid関数を挟む
self.conv4=nn.Sequential(
nn.Conv2d(4*ndf,1,3,1,0),
nn.Sigmoid()
)
#図の立方体のネットワークつまり、畳み込み + Batch Norm + LeakyReluの構造を作成
@staticmethod
def conv_layers(in_channels,out_channels,kernel_size=4,stride=2,padding=1,has_batch_norm=True):
layers=[
nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,bias=False)
]
#batch_norm層ありを作る場合
if has_batch_norm:
layers.append(nn.BatchNorm2d(out_channels))
layers.append(nn.LeakyReLU(0.2,inplace=True))
#生成器の時は直接書いていたが、リストをあらかじめ作成しておき後からSequentialに渡すこともできる。
net=nn.Sequential(*layers)
return net
#フォワード計算
def forward(self,x):
#1層目にインプットを与え、出力されたアウトプットを2層目のインプットにする。ということを4層目まで行う。
out=self.conv1(x)
out=self.conv2(out)
out=self.conv3(out)
out=self.conv4(out)
return out
▶層の初期化関数と前処理の定義
続いて作成した層の初期化を行う関数を定義します。初期化されたパラメータは正規分布に従って、ランダムに設定するようにします。
def weight_init(m):
#クラス名(何の層なのか)を取得
classname=m.__class__.__name__
#畳み込み層の場合
if classname.find("Conv") != -1:
nn.init.normal_(m.weight.data,0,0.02)
#BatchNorm層の場合
elif classname.find("BatchNorm") != -1:
nn.init.normal_(m.weight.data,1,0.02)
nn.init.constant_(m.bias.data,0)
さらに前処理にて行う変換を定義します。個々ではテンソル配列への変換と正則化を前処理として入れておきます。
#変換の定義
#テンソル配列への変換、正則化を行う
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,),(0.5,))
])
▶訓練の実装
訓練プロセスの実装に移ります。まずは結果及びパラメータの保存先の作成、パラメータ保存用の関数、データローダーの作成、デバイスの設定を行います。
#結果、パラメータの保存先を指定
dir_path_1="./result"
dir_path_2="./params"
#ディレクトリがない場合は作成する
os.makedirs(dir_path_1,exist_ok=True)
os.makedirs(dir_path_2,exist_ok=True)
#パラメータ保存用の関数を定義
def save_params(file_path,epoch,netD,netG):
torch.save(
netG.state_dict(),
file_path+"g_{:04d}.pth".format(epoch)
)
torch.save(
netD.state_dict(),
file_path+"d_{:04d}.pth".format(epoch)
)
#データローダにデータを渡す
dataset=dset.MNIST(root="./data/",download=False,train=True,transform=transform)
dataloader=DataLoader(dataset=dataset,batch_size=opt.batch_size,shuffle=True)
#デバイスの設定
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
次に上記で定義した生成器、識別器のクラスからオブジェクトを作成して層の初期化を行います。
#生成器作成
netG=Generator(z_dim=opt.z_dim,ngf=opt.nch_g).to(device)
#層の初期化
netG.apply(weight_init)
print(netG)
#識別器作成
netD=Discriminator(nc=1,ndf=opt.nch_d).to(device)
#層の初期化
netD.apply(weight_init)
print(netD)
結果を表示すると下記のようになっており、想定通りのネットワークが形成されていることが確認できます。
Generator(
(convt1): Sequential(
(0): ConvTranspose2d(100, 512, kernel_size=(3, 3), stride=(1, 1), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(convt2): Sequential(
(0): ConvTranspose2d(512, 256, kernel_size=(3, 3), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(convt3): Sequential(
(0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(convt4): Sequential(
(0): ConvTranspose2d(128, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(1): Tanh()
)
)
Discriminator(
(conv1): Sequential(
(0): Conv2d(1, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
)
(conv2): Sequential(
(0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): LeakyReLU(negative_slope=0.2, inplace=True)
)
(conv3): Sequential(
(0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): LeakyReLU(negative_slope=0.2, inplace=True)
)
(conv4): Sequential(
(0): Conv2d(512, 1, kernel_size=(3, 3), stride=(1, 1))
(1): Sigmoid()
)
)
続いて、損失関数としてBinaryCrossEntropyを与え、OptimizerにはAdamを用いることを定義します。
#損失関数(BinaryCrossEntropy)
criterion=nn.BCELoss()
#生成器のOptimizer(Adam)を定義
optimizerG=optim.Adam(netG.parameters(),lr=opt.lr,betas=(opt.beta1,0.999),weight_decay=1e-5)
#識別器のOptimizer(Adam)を定義
optimizerD=optim.Adam(netD.parameters(),lr=opt.lr,betas=(opt.beta1,0.999),weight_decay=1e-5)
準備が整ったので訓練についての実装を下記のように行います。
#損失関数を格納するリストを作成
lossesD=[]
lossesG=[]
#イテレーションごとの損失関数を格納するリストを作成
raw_lossesD=[]
raw_lossesG=[]
#エポックごとのループ処理
for epoch in range(opt.n_epoch):
#損失関数の値の初期化
running_lossD=0.0
running_lossG=0.0
#ミニバッチごとのループ処理
for i,(real_imgs,labels) in enumerate(tqdm.tqdm(dataloader,position=0)):
#画像をデバイスに送る
real_imgs=real_imgs.to(device)
#バッチサイズ指定(最後のループでは指定した以外のサイズになるので、このように指定している)
batch_size=real_imgs.size()[0]
#生成器に与えるノイズを定義
noise=torch.randn(batch_size,opt.z_dim,1,1).to(device)
#labelの値を定義する 本物の画像:1 偽の画像:0
shape=(batch_size,1,1,1)
label_real=torch.ones(shape).to(device)
label_fake=torch.zeros(shape).to(device)
#####################
#####識別器の学習#####
#####################
#勾配の初期化
netD.zero_grad()
#本物の画像の識別を行わせ、損失関数の値を取得する
output=netD(real_imgs)
lossD_real=criterion(output,label_real)
#偽の画像を生成器で作成し、これを識別器に識別を行わせ、損失関数の値を取得する
fake_imgs=netG(noise)
output=netD(fake_imgs.detach())
lossD_fake=criterion(output,label_fake)
#トータルの損失関数を取得
lossD=lossD_real+lossD_fake
#トータルの損失関数を元にバックワードさせ、optimizerを使い1ステップ最適化させる
lossD.backward()
optimizerD.step()
#####################
#####生成器の学習#####
#####################
#勾配の初期化
netG.zero_grad()
#偽の画像(fake_imgs)を識別器に渡し、本物と判定させたいので損失関数に本物の画像のラベルを渡す
output=netD(fake_imgs)
lossG=criterion(output,label_real)
#損失関数を元にバックワードさせ、optimizerを使い1ステップ最適化させる
lossG.backward()
optimizerG.step()
#####################
######損失の保存######
#####################
#ミニバッチごとの損失を足していき、エポックごとの合計損失を計算
running_lossD+=lossD.item()
running_lossG+=lossG.item()
#ミニバッチごとの損失を保存
raw_lossesD.append(lossD.item())
raw_lossesG.append(lossG.item())
#エポックごとの合計損失を平均化して保存
running_lossD/=len(dataloader)
running_lossG/=len(dataloader)
lossesD.append(running_lossD)
lossesG.append(running_lossG)
print("epoch: {}, lossD: {}, lossG: {}".format(epoch+1,running_lossD,running_lossG))
#####################
#####偽の画像表示#####
#####################
#24枚分を合わせて表示するように指定
grid_imgs=vutils.make_grid(fake_imgs[:24].detach())
grid_imgs_arr=grid_imgs.cpu().numpy()
plt.imshow(np.transpose(grid_imgs_arr,(1,2,0)))
plt.show()
########################
#偽の画像とパラメータ保存#
########################
#全て出力すると邪魔なので、10回毎と学習の最後のみに出力
if (epoch)%10==0:
vutils.save_image(fake_imgs,"./result/epoch_{}.jpg".format(epoch))
save_params("./params/",epoch,netD,netG)
elif epoch==(opt.n_epoch-1):
vutils.save_image(fake_imgs,"./result/epoch_last.jpg")
torch.save(netG.state_dict(),"./g_last.pth")
torch.save(netD.state_dict(),"./d_last.pth")
これを実行することで学習が進行します。学習の途中で生成器によって生成された画像を出力してこれを見ると下記のように画像が生成されています。初期の段階では若干ぼやけていたものが徐々にはっきりと文字とわかるようになり、形も数字と判断できるようになっています。
最後に保存していた損失関数をグラフに起こしてみることで学習がどのように進んでいったのかを確認したいと思います。
生成器と識別器の損失関数を同じグラフ上に表示すると下記のようになります。