Background
Most DCGAN examples generate images of animals or human faces, and other subjects are rarely covered, so I tried generating underwater images with a DCGAN.
Principle
A DCGAN combines a CNN with a GAN: it brings convolutional networks into the generative model and trains them in an unsupervised way, using the strong feature-extraction ability of convolutions to improve what the generator network learns.
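The generator's architecture is driven by simple size arithmetic: each transposed convolution produces an output of size (in - 1) * stride - 2 * padding + kernel. A minimal sketch checking the five layer settings used below (the parameters mirror the generator in the Code section):

# Output size of one transposed convolution:
# out = (in - 1) * stride - 2 * pad + kernel
def deconv_out(size, kernel, stride, pad):
    return (size - 1) * stride - 2 * pad + kernel

size = 1  # the latent vector z enters as nz x 1 x 1
for kernel, stride, pad in [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (5, 3, 1)]:
    size = deconv_out(size, kernel, stride, pad)
    print(size)  # prints 4, 8, 16, 32, 96

The last layer's unusual (kernel=5, stride=3) choice is what lifts 32x32 to the 96x96 output resolution.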
Code
DCGAN network (model.py)
import torch.nn as nn

# Generator network G: maps a latent vector (nz x 1 x 1) to a 3 x 96 x 96 image
class NetG(nn.Module):
    def __init__(self, ngf, nz):
        super(NetG, self).__init__()
        self.layer1 = nn.Sequential(
            # (nz, 1, 1) -> (ngf*8, 4, 4)
            nn.ConvTranspose2d(nz, ngf * 8, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(inplace=True)
        )
        self.layer2 = nn.Sequential(
            # (ngf*8, 4, 4) -> (ngf*4, 8, 8)
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(inplace=True)
        )
        self.layer3 = nn.Sequential(
            # (ngf*4, 8, 8) -> (ngf*2, 16, 16)
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(inplace=True)
        )
        self.layer4 = nn.Sequential(
            # (ngf*2, 16, 16) -> (ngf, 32, 32)
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(inplace=True)
        )
        self.layer5 = nn.Sequential(
            # (ngf, 32, 32) -> (3, 96, 96); Tanh maps pixels to [-1, 1]
            nn.ConvTranspose2d(ngf, 3, 5, 3, 1, bias=False),
            nn.Tanh()
        )

    # Forward pass
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        return out
# Discriminator network D: maps a 3 x 96 x 96 image to a real/fake probability
class NetD(nn.Module):
    def __init__(self, ndf):
        super(NetD, self).__init__()
        self.layer1 = nn.Sequential(
            # (3, 96, 96) -> (ndf, 32, 32)
            nn.Conv2d(3, ndf, kernel_size=5, stride=3, padding=1, bias=False),
            nn.BatchNorm2d(ndf),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.layer2 = nn.Sequential(
            # (ndf, 32, 32) -> (ndf*2, 16, 16)
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.layer3 = nn.Sequential(
            # (ndf*2, 16, 16) -> (ndf*4, 8, 8)
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.layer4 = nn.Sequential(
            # (ndf*4, 8, 8) -> (ndf*8, 4, 4)
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.layer5 = nn.Sequential(
            # (ndf*8, 4, 4) -> (1, 1, 1); Sigmoid gives a probability for BCELoss
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    # Forward pass
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        # Flatten (N, 1, 1, 1) to (N,) so the output matches the label shape in BCELoss
        return out.view(-1)
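Before training, a quick smoke test confirms the shape contract between the two networks. This is a sketch that assumes the classes above are saved as model.py (the same import the training script below uses); note that the forward pass above flattens D's output to one value per image:

import torch
from model import NetD, NetG

netG = NetG(ngf=64, nz=100)
netD = NetD(ndf=64)
z = torch.randn(8, 100, 1, 1)      # a small batch of latent vectors
fake = netG(z)
print(fake.shape)                  # torch.Size([8, 3, 96, 96])
print(netD(fake).shape)            # torch.Size([8]); one probability per image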
Training script
import argparse
import torch
import torch.nn as nn
import torch.utils.data
import torchvision
import torchvision.utils as vutils
from model import NetD, NetG

parser = argparse.ArgumentParser()
parser.add_argument('--batchSize', type=int, default=64)
parser.add_argument('--imageSize', type=int, default=96)
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--epoch', type=int, default=1000, help='number of epochs')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam, default=0.5')
parser.add_argument('--data_path', default='data/', help='folder of training data')
parser.add_argument('--outf', default='imgs/', help='folder for output images and model checkpoints')
opt = parser.parse_args()

# Use the GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the images; Normalize to [-1, 1] to match the generator's Tanh output
# (transforms.Scale was removed from torchvision; Resize is its replacement)
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(opt.imageSize),
    torchvision.transforms.CenterCrop(opt.imageSize),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = torchvision.datasets.ImageFolder(opt.data_path, transform=transforms)
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=opt.batchSize,
    shuffle=True,
    drop_last=True,
)

netG = NetG(opt.ngf, opt.nz).to(device)
netD = NetD(opt.ndf).to(device)
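# Hedged addition, not in the original script: the DCGAN paper recommends
# initializing conv weights from N(0, 0.02) and BatchNorm weights from
# N(1, 0.02), which often stabilizes early training.
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

netG.apply(weights_init)
netD.apply(weights_init)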
criterion = nn.BCELoss()
optimizerG = torch.optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerD = torch.optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

# A reusable label tensor, filled with 1 (real) or 0 (fake) before each loss
label = torch.FloatTensor(opt.batchSize).to(device)
real_label = 1
fake_label = 0

for epoch in range(1, opt.epoch + 1):
    for i, (imgs, _) in enumerate(dataloader):
        # Update D: maximize log(D(x)) + log(1 - D(G(z)))
        optimizerD.zero_grad()
        imgs = imgs.to(device)
        output = netD(imgs)
        label.fill_(real_label)
        errD_real = criterion(output, label)
        errD_real.backward()
        label.fill_(fake_label)
        noise = torch.randn(opt.batchSize, opt.nz, 1, 1, device=device)
        fake = netG(noise)
        output = netD(fake.detach())  # detach so no gradient flows into G here
        errD_fake = criterion(output, label)
        errD_fake.backward()
        errD = errD_fake + errD_real
        optimizerD.step()

        # Update G: maximize log(D(G(z))), i.e. train G to fool D
        optimizerG.zero_grad()
        label.fill_(real_label)
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        optimizerG.step()

        print('[%d/%d][%d/%d] Loss_D: %.3f Loss_G %.3f'
              % (epoch, opt.epoch, i, len(dataloader), errD.item(), errG.item()))

    # Save sample images and checkpoints once per epoch
    vutils.save_image(fake.data,
                      '%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch),
                      normalize=True)
    torch.save(netG.state_dict(), '%s/netG_%03d.pth' % (opt.outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_%03d.pth' % (opt.outf, epoch))
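Once training has produced checkpoints, samples like the ones below can be regenerated from a saved generator. A minimal sketch, assuming the default --outf of imgs/ and the epoch-1000 checkpoint name produced by the script above:

import torch
import torchvision.utils as vutils
from model import NetG

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
netG = NetG(ngf=64, nz=100).to(device)
netG.load_state_dict(torch.load('imgs/netG_1000.pth', map_location=device))
netG.eval()

with torch.no_grad():
    fake = netG(torch.randn(64, 100, 1, 1, device=device))
vutils.save_image(fake, 'fake_samples.png', normalize=True)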
Results
epoch 100
epoch 300
epoch 500
epoch 800
epoch 1000
Summary
The quality of the generated images is poor, and the model also shows signs of overfitting; the most likely cause is that the dataset contains too few images.
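If the small dataset really is the root cause, label-free data augmentation is a cheap first step before collecting more images. A sketch of an augmented input pipeline (a hypothetical drop-in replacement for transforms in the training script; the flip and jitter parameters are illustrative, not tuned):

import torchvision.transforms as T

transforms = T.Compose([
    T.Resize(96),
    T.CenterCrop(96),
    T.RandomHorizontalFlip(p=0.5),                # mirror images to double the effective data
    T.ColorJitter(brightness=0.1, contrast=0.1),  # mild photometric variation
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])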