今回すること
前に超解像度化を行う画像処理用のモデルを作成したので、次は画像の生成を行うモデルを作成してみることにしました。そこで、画像の生成と言えばGANかなと思い、その中でも比較的簡単なモデルであるDCGANを今回は実装、学習して見ようと考えています。
DCGANについて
GAN(Generative adversarial network)の思想は単純で、贋作の製作者とそれを見破る鑑定士の競争によってより高精度な贋作を作っていくというものです。したがってネットワークの構造は、
・適当なノイズを入力として画像を生成するGenerator
・画像を入力としてその真贋を判定するDiscriminator
の二つで基本的にできています。
細かい仕組みについては今さら聞けないGAN(1) 基本構造の理解こちらの方で勉強させていただきました。
DCGAN(Deep Convolutional Generative adversarial network)はその中でも逆畳み込み層を用いて画像を生成するモデルです。
実装
コードは以下の四つのファイルに分けて実装しました。
・ networks.py : ネットワーク構造について書いている
・ utils.py : Datasetやloss関数について書いている
・ train.py : モデルの学習を行う
・ generate.py : 学習したモデルを使って画像を生成する。
networks.py
ここではGeneratorとDiscriminatorの構造を定義します。今回はDCGANのため単純な構造となっています。
import torch
from torch import nn
class Generator(nn.Module):
def __init__(self, latent_size = 100):
super().__init__()
self.main = nn.Sequential(
nn.ConvTranspose2d(latent_size, 256, 4, 1, 0, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(32, 3, 4, 2, 1, bias=False),
nn.Tanh()
)
def forward(self, x):
return self.main(x)
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.main = nn.Sequential(
nn.Conv2d(3, 32, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(32, 64, 4, 2, 1, bias=False),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 128, 4, 2, 1, bias=False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(128, 256, 4, 2, 1, bias=False),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(256, 1, 4, 1, 0, bias=False),
)
def forward(self, x):
return self.main(x).squeeze()
utils.py
ここでは学習の際に使いやすいようにデータセットを定義します。
import os
from torch.utils.data import Dataset
from torchvision import transforms
import torch
from PIL import Image
class Dcgan_Dataset(Dataset):
def __init__(self, root, datamode = "train", transform = transforms.ToTensor(), latent_size=100):
self.image_dir = os.path.join(root, datamode)
self.image_paths = [os.path.join(self.image_dir, name) for name in os.listdir(self.image_dir)]
self.data_length = len(self.image_paths)
self.transform = transform
self.latent_size = latent_size
def __len__(self):
return self.data_length
def __getitem__(self, index):
latent = torch.randn(size=(self.latent_size, 1, 1))
img_path = self.image_paths[index]
img = Image.open(img_path)
if not self.transform is None:
img = self.transform(img)
return latent, img
## train.py
ここでは学習を定義します。テストを行なっていることや、tensorboardに記録を行なっているせいで、少し長めになってしまいました。
import os
import argparse
from networks import Generator, Discriminator
from utils import Dcgan_Dataset
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
from tqdm import tqdm
def main(opt):
torch.manual_seed(opt.seed)
torch.cuda.manual_seed(opt.seed)
# ----- Device Setting -----
if opt.gpu is True:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
device = torch.device("cpu")
print("Device :", device)
# ----- Dataset Setting -----
train_dataset = Dcgan_Dataset(opt.dataset, datamode="train",
transform=transforms.Compose([transforms.RandomHorizontalFlip(),
transforms.ToTensor()]))
test_dataset = Dcgan_Dataset(opt.dataset, datamode="test")
print("Training Dataset :", os.path.join(opt.dataset, "train"))
print("Testing Dataset :", os.path.join(opt.dataset, "test"))
# ----- DataLoader Setting -----
train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=opt.test_batch_size, shuffle=True)
print("batch_size :",opt.batch_size)
print("test_batch_size :",opt.test_batch_size)
# ----- Summary Writer Setting -----
train_writer = SummaryWriter(log_dir=os.path.join(opt.tensorboard, opt.exper))
test_writer = SummaryWriter(log_dir=os.path.join(opt.tensorboard, opt.exper + "_test"))
print("log directory :",os.path.join(opt.tensorboard, opt.exper))
print("log step :", opt.n_log_step)
# ----- Net Work Setting -----
latent_size = opt.latent_size
model_D = Discriminator()
model_G = Generator()
# resume
if opt.resume_epoch != 0:
model_D_path = os.path.join(opt.checkpoints_dir, opt.exper, "model_D_{}.pth".format(str(opt.resume_epoch)))
model_G_path = os.path.join(opt.checkpoints_dir, opt.exper, "model_G_{}.pth".format(str(opt.resume_epoch)))
model_G.load_state_dict(torch.load(model_G_path, map_location="cpu"))
model_D.load_state_dict(torch.load(model_D_path, map_location="cpu"))
model_D.to(device)
model_G.to(device)
model_D.train()
model_G.train()
# ロスを計算するときのラベル変数
ones = torch.ones(opt.batch_size).to(device) # 正例 1
zeros = torch.zeros(opt.batch_size).to(device) # 負例 0
val_latents = torch.randn(9, opt.latent_size, 1, 1).to(device)
loss_f = nn.BCEWithLogitsLoss()
optimizer_D = torch.optim.Adam(model_D.parameters(), lr=0.0002)
optimizer_G = torch.optim.Adam(model_G.parameters(), lr=0.0002)
print("Latent size :",opt.latent_size)
# ----- Training Loop -----
step = 0
for epoch in tqdm(range(opt.resume_epoch, opt.resume_epoch + opt.epoch)):
print("epoch :",epoch + 1,"/", opt.resume_epoch + opt.epoch)
# for latent, real_img in tqdm(train_loader):
for latent, real_img in train_loader:
step += 1
latent = latent.to(device)
real_img = real_img.to(device)
batch_len = len(real_img)
fake_img = model_G(latent)
pred_fake = model_D(fake_img)
loss_G = loss_f(pred_fake, ones[: batch_len])
model_D.zero_grad()
model_G.zero_grad()
loss_G.backward()
optimizer_G.step()
pred_real = model_D(real_img)
loss_D_real = loss_f(pred_real, ones[: batch_len])
fake_img = model_G(latent)
pred_fake = model_D(fake_img)
loss_D_fake = loss_f(pred_fake, zeros[: batch_len])
loss_D = loss_D_real + loss_D_fake
model_D.zero_grad()
model_G.zero_grad()
loss_D.backward()
optimizer_D.step()
if step % opt.n_log_step == 0:
# test step
model_G.eval()
model_D.eval()
test_d_losses = []
test_d_real_losses = []
test_d_fake_losses = []
test_g_losses = []
for test_latent, test_real_img in test_loader:
test_latent = test_latent.to(device)
test_real_img = test_real_img.to(device)
batch_len = len(test_latent)
test_pred_img = model_G(test_latent)
test_fake_g = model_D(test_pred_img)
test_g_loss = loss_f(test_fake_g, ones[: batch_len])
test_g_losses.append(test_g_loss.item())
test_fake_d = model_D(test_pred_img)
test_real_d = model_D(test_real_img)
test_d_real_loss = loss_f(test_real_d, ones[: batch_len])
test_d_fake_loss = loss_f(test_fake_d, zeros[: batch_len])
test_d_loss = test_d_real_loss + test_d_fake_loss
test_d_real_losses.append(test_d_real_loss.item())
test_d_fake_losses.append(test_d_fake_loss.item())
test_d_losses.append(test_d_loss.item())
# record process
test_g_loss = sum(test_g_losses)/len(test_g_losses)
test_d_loss = sum(test_d_losses)/len(test_d_losses)
test_d_real_loss = sum(test_d_real_losses)/len(test_d_real_losses)
test_d_fake_loss = sum(test_d_fake_losses)/len(test_d_fake_losses)
train_writer.add_scalar("loss/g_loss", loss_G.item(), step)
train_writer.add_scalar("loss/d_loss", loss_D.item(), step)
train_writer.add_scalar("loss/d_real_loss", loss_D_real.item(), step)
train_writer.add_scalar("loss/d_fake_loss", loss_D_fake.item(), step)
train_writer.add_scalar("loss/epoch", epoch + 1, step)
test_writer.add_scalar("loss/g_loss", test_g_loss, step)
test_writer.add_scalar("loss/d_loss", test_d_loss, step)
test_writer.add_scalar("loss/d_real_loss", test_d_real_loss, step)
test_writer.add_scalar("loss/d_fake_loss", test_d_fake_loss, step)
pred_img = model_G(val_latents)
grid_img = make_grid(pred_img, nrow=3, padding=0)
grid_img = grid_img.mul(0.5).add_(0.5)
train_writer.add_image("train/{}".format(epoch), grid_img, step)
model_D.train()
model_G.train()
if (epoch + 1) % opt.n_save_epoch == 0:
save_dir = os.path.join(opt.checkpoints_dir, opt.exper)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
model_g_path = os.path.join(save_dir, "model_G_{}.pth".format(str(epoch + 1)))
model_d_path = os.path.join(save_dir, "model_D_{}.pth".format(str(epoch + 1)))
torch.save(model_D.state_dict(), model_d_path)
torch.save(model_G.state_dict(), model_g_path)
print("save_model")
# save model
save_dir = os.path.join(opt.checkpoints_dir, opt.exper)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
model_g_path = os.path.join(save_dir, "model_G_{}.pth".format(str(opt.epoch)))
model_d_path = os.path.join(save_dir, "model_D_{}.pth".format(str(opt.epoch)))
torch.save(model_D.state_dict(), model_d_path)
torch.save(model_G.state_dict(), model_g_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", default="../dataset/face_crop_img")
parser.add_argument("--checkpoints_dir", default="../checkpoints")
parser.add_argument("--exper", default="dcgan")
parser.add_argument("--tensorboard", default="../tensorboard")
parser.add_argument("--gpu", action="store_true", default=False)
parser.add_argument("--epoch", type=int, default=10)
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--test_batch_size", type=int, default=4)
parser.add_argument("--n_log_step", type=int, default=10)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--n_save_epoch", type=int, default=10)
parser.add_argument("--latent_size", type=int, default=100)
# resume
parser.add_argument("--resume_epoch", type=int, default=0)
opt = parser.parse_args()
main(opt)
generate.py
ここではモデルを読み込んで、画像を生成する動作を定義します。
import os
import argparse
from networks import Generator
import numpy as np
from tqdm import tqdm
import torch
from torchvision.utils import save_image
def generate(latents, model_G, width, height, save_path):
assert len(latents) == width * height
pred_images = model_G(latents)
save_image(pred_images, save_path, nrow=width)
def generate_process(opt):
# ----- Device Setting -----
device = torch.device("cpu")
# ----- Output Setting -----
img_output_dir = opt.output_dir
if not os.path.exists(img_output_dir):
os.mkdir(img_output_dir)
if not os.path.exists(os.path.join(img_output_dir, "images")):
os.mkdir(os.path.join(img_output_dir, "images"))
if opt.save_latent is True:
if not os.path.exists(os.path.join(img_output_dir, "latents")):
os.mkdir(os.path.join(img_output_dir, "latents"))
print("Output :", img_output_dir)
# ----- Model Loading -----
print("Use model :", opt.model)
model_g = Generator()
model_g.load_state_dict(torch.load(opt.model, map_location="cpu"))
model_g.to(device)
model_g.eval()
if opt.mode == "normal":
latents = [torch.randn(size=(opt.width * opt.height, opt.latent_size, 1, 1) for i in range(opt.n_img)]
elif opt.mode == "use_latent":
assert opt.latent_dir != "None", "latent source directory is not set"
latent_paths = [os.path.join(opt.latent_dir, name) for name in os.listdir(opt.latent_dir)]
latents = [torch.from_numpy(np.load(path)) for path in latent_paths]
elif opt.mode == "inter":
latent_start = torch.from_numpy(np.load(opt.start_latent))
latent_end = torch.from_numpy(np.load(opt.end_latent))
alphas = [float(n / opt.latent_num) for n in range(opt.n_img)]
latents = [alpha * latent_end + (1 - alpha) * latent_start for alpha in alphas]
print("Generate image num :", len(latents))
# ----- Generate Step -----
print("Start Generate Process")
for index,latent in tqdm(enumerate(latents)):
img_path = os.path.join(img_output_dir, "images", str(index + 1) + ".png")
generate(latent, model_g, opt.width, opt.height, img_path)
if opt.save_latent:
latent_path = os.path.join(img_output_dir, "latents", str(index + 1) + ".npy")
np.save(latent_path, latent.numpy())
print("Finish Generate Process")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="../checkpoints/dcgan_test/model_G_1000.pth")
parser.add_argument("--output_dir", default="./result")
parser.add_argument("--save_latent", action="store_true", default=False)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--width", type=int, default=1)
parser.add_argument("--height", type=int, default=1)
parser.add_argument("--mode", choice=["normal", "use_latent", "inter"], default="normal")
parser.add_argument("--n_img", type=int, default=1)
# normal generation
# use latent
parser.add_argument("--latent_dir", default="None")
# intermediate vectl
parser.add_argument("--generate_inter", action="store_true", default=False)
parser.add_argument("--start_latent", type=str, default=".")
parser.add_argument("--end_latent", type=str, default=".")
opt = parser.parse_args()
generate(opt)
学習結果
データセットは18000枚のアニメ顔画像で行い、学習はGoogle Colaboratoryを使って行いました。以下が学習の結果です。
10epoch
20epoch
30epoch
40epoch
50epoch
100epoch
200epoch
400epoch
600epoch
800epoch
1000epoch
最後に
今回は、DCGANの実装と学習を行いました。やはり生成画像の精度はそこまで高くはなりませんでした。データセットもアノーテーションを行なっていないので、質が低くなってしまうようです。また、今回は64×64の画像の生成のため、そこまで難しくはありませんでしたが、これがもっと高解像度になってくると、DCGANでは学習時間が膨大になってくると思うので、他のモデルについても勉強しておこうとおもいます。