Implementing and Training a DCGAN in PyTorch

Posted at 2020-05-01

What this post covers

 Having previously built an image-processing model for super-resolution, I decided to build an image-generation model next. GANs are the obvious choice for image generation, so this time I implement and train a DCGAN, one of the comparatively simple GAN models.

About DCGAN

 The idea behind a GAN (Generative Adversarial Network) is simple: a forger and an appraiser who tries to spot forgeries compete with each other, and the competition drives the forgeries to become ever more convincing. The network therefore basically consists of two parts:
・a Generator, which takes random noise as input and generates an image
・a Discriminator, which takes an image as input and judges whether it is real or fake
 For the finer details of the mechanism, I learned from the article 今さら聞けないGAN(1) 基本構造の理解.
 DCGAN (Deep Convolutional Generative Adversarial Network) is a GAN that generates images using transposed convolution (often loosely called "deconvolution") layers.
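To illustrate how a transposed convolution upsamples, here is a minimal sketch. The layer sizes match the first Generator layer in the code below, where the 100-dimensional latent enters as a 1×1 "image":

```python
import torch
from torch import nn

# Output size of a transposed convolution:
#   out = (in - 1) * stride - 2 * padding + kernel_size
# Here a 1x1 latent "image" is upsampled to a 4x4 feature map.
deconv = nn.ConvTranspose2d(100, 256, kernel_size=4, stride=1, padding=0)
z = torch.randn(1, 100, 1, 1)
out = deconv(z)
print(out.shape)  # torch.Size([1, 256, 4, 4])
```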

Implementation

 The code is split across the following four files:
・networks.py : defines the network architectures
・utils.py : defines the Dataset (the loss function actually ends up defined in train.py)
・train.py : trains the models
・generate.py : generates images with a trained model

networks.py

 Here we define the Generator and Discriminator architectures. Since this is a plain DCGAN, both are simple.

networks.py

import torch
from torch import nn

class Generator(nn.Module):
    def __init__(self, latent_size = 100):
        super().__init__()
        self.main = nn.Sequential(

            nn.ConvTranspose2d(latent_size, 256, 4, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(32, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(

            nn.Conv2d(3, 32, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(32, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 1, 4, 1, 0, bias=False),
        )

    def forward(self, x):
        return self.main(x).squeeze()
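As a sanity check that the Generator really turns a latent into a 64×64 RGB image, the spatial sizes can be traced with a stripped-down copy of its transposed convolutions (same channel, kernel, stride, and padding settings as above; BatchNorm and activations do not change shapes, so they are omitted here):

```python
import torch
from torch import nn

# Each stride-2 ConvTranspose2d(kernel=4, stride=2, padding=1) doubles
# the spatial size; the first layer expands the 1x1 latent to 4x4.
g = nn.Sequential(
    nn.ConvTranspose2d(100, 256, 4, 1, 0, bias=False),  # 1x1   -> 4x4
    nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),  # 4x4   -> 8x8
    nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),   # 8x8   -> 16x16
    nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),    # 16x16 -> 32x32
    nn.ConvTranspose2d(32, 3, 4, 2, 1, bias=False),     # 32x32 -> 64x64
)
z = torch.randn(2, 100, 1, 1)
print(g(z).shape)  # torch.Size([2, 3, 64, 64])
```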

utils.py

Here we define the dataset so that it is convenient to use during training.

utils.py
import os

from torch.utils.data import Dataset
from torchvision import transforms
import torch
from PIL import Image

class Dcgan_Dataset(Dataset):
    def __init__(self, root, datamode = "train", transform = transforms.ToTensor(), latent_size=100):

        self.image_dir = os.path.join(root, datamode)
        self.image_paths = [os.path.join(self.image_dir, name) for name in os.listdir(self.image_dir)]
        self.data_length = len(self.image_paths)

        self.transform = transform
        self.latent_size = latent_size

    def __len__(self):
        return self.data_length
    
    def __getitem__(self, index):
        latent = torch.randn(size=(self.latent_size, 1, 1))
        img_path = self.image_paths[index]
        img = Image.open(img_path)

        if self.transform is not None:
            img = self.transform(img)
        
        return latent, img
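Note that `__getitem__` draws a fresh random latent on every call, so an image is paired with a different noise vector each epoch. The shapes involved look like this (a standalone sketch, no image files needed):

```python
import torch

latent_size = 100
# one sample's latent, as returned by __getitem__
latent = torch.randn(size=(latent_size, 1, 1))
# a DataLoader then stacks samples into a (batch, latent_size, 1, 1) tensor
batch = torch.stack([torch.randn(size=(latent_size, 1, 1)) for _ in range(16)])
print(latent.shape, batch.shape)  # torch.Size([100, 1, 1]) torch.Size([16, 100, 1, 1])
```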


train.py

Here we define the training loop. Because it also runs a test pass and logs to TensorBoard, it ended up somewhat long.


train.py
import os
import argparse

from networks import Generator, Discriminator
from utils import Dcgan_Dataset

import numpy as np

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

from tqdm import tqdm



def main(opt):
    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed(opt.seed)
    # ----- Device Setting -----
    if opt.gpu:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cpu")
    print("Device :", device)

    # ----- Dataset Setting -----
    train_dataset = Dcgan_Dataset(opt.dataset, datamode="train",
                                  transform=transforms.Compose([transforms.RandomHorizontalFlip(),
                                                                transforms.ToTensor()]))
    
    test_dataset = Dcgan_Dataset(opt.dataset, datamode="test")

    print("Training Dataset :", os.path.join(opt.dataset, "train"))
    print("Testing Dataset :", os.path.join(opt.dataset, "test"))

    # ----- DataLoader Setting -----
    train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=opt.test_batch_size, shuffle=True)

    print("batch_size :",opt.batch_size)
    print("test_batch_size :",opt.test_batch_size)

    # ----- Summary Writer Setting -----
    train_writer = SummaryWriter(log_dir=os.path.join(opt.tensorboard, opt.exper))
    test_writer = SummaryWriter(log_dir=os.path.join(opt.tensorboard, opt.exper + "_test"))

    print("log directory :",os.path.join(opt.tensorboard, opt.exper))
    print("log step :", opt.n_log_step)

    # ----- Net Work Setting -----
    latent_size = opt.latent_size
    model_D = Discriminator()
    model_G = Generator()

    # resume
    if opt.resume_epoch != 0:
        model_D_path = os.path.join(opt.checkpoints_dir, opt.exper, "model_D_{}.pth".format(str(opt.resume_epoch)))
        model_G_path = os.path.join(opt.checkpoints_dir, opt.exper, "model_G_{}.pth".format(str(opt.resume_epoch)))

        model_G.load_state_dict(torch.load(model_G_path, map_location="cpu"))
        model_D.load_state_dict(torch.load(model_D_path, map_location="cpu"))

    model_D.to(device)
    model_G.to(device)
    model_D.train()
    model_G.train()

    # label tensors used when computing the loss
    ones = torch.ones(opt.batch_size).to(device) # real label: 1
    zeros = torch.zeros(opt.batch_size).to(device) # fake label: 0

    val_latents = torch.randn(9, opt.latent_size, 1, 1).to(device)
    loss_f = nn.BCEWithLogitsLoss()

    optimizer_D = torch.optim.Adam(model_D.parameters(), lr=0.0002)
    optimizer_G = torch.optim.Adam(model_G.parameters(), lr=0.0002)

    print("Latent size :",opt.latent_size)

    # ----- Training Loop -----
    step = 0
    for epoch in tqdm(range(opt.resume_epoch, opt.resume_epoch + opt.epoch)):
        print("epoch :",epoch + 1,"/", opt.resume_epoch + opt.epoch)

        # for latent, real_img in tqdm(train_loader):
        for latent, real_img in train_loader:
            step += 1
            latent = latent.to(device)
            real_img = real_img.to(device)
            batch_len = len(real_img)

            fake_img = model_G(latent)

            pred_fake = model_D(fake_img)
            loss_G = loss_f(pred_fake, ones[: batch_len])

            model_D.zero_grad()
            model_G.zero_grad()
            loss_G.backward()
            optimizer_G.step()

            pred_real = model_D(real_img)
            loss_D_real = loss_f(pred_real, ones[: batch_len])
            fake_img = model_G(latent)
            pred_fake = model_D(fake_img)
            loss_D_fake = loss_f(pred_fake, zeros[: batch_len])
            loss_D = loss_D_real + loss_D_fake

            model_D.zero_grad()
            model_G.zero_grad()
            loss_D.backward()
            optimizer_D.step()

            if step % opt.n_log_step == 0:
                # test step 
                model_G.eval()
                model_D.eval()
                test_d_losses = []
                test_d_real_losses = []
                test_d_fake_losses = []
                test_g_losses = []
                for test_latent, test_real_img in test_loader:
                    test_latent = test_latent.to(device)
                    test_real_img = test_real_img.to(device)
                    batch_len = len(test_latent)
                    test_pred_img = model_G(test_latent)
                    test_fake_g = model_D(test_pred_img)
                    test_g_loss = loss_f(test_fake_g, ones[: batch_len])

                    test_g_losses.append(test_g_loss.item())

                    test_fake_d = model_D(test_pred_img)
                    test_real_d = model_D(test_real_img)
                    test_d_real_loss = loss_f(test_real_d, ones[: batch_len])
                    test_d_fake_loss = loss_f(test_fake_d, zeros[: batch_len])
                    test_d_loss = test_d_real_loss + test_d_fake_loss

                    test_d_real_losses.append(test_d_real_loss.item())
                    test_d_fake_losses.append(test_d_fake_loss.item())
                    test_d_losses.append(test_d_loss.item())
                
                # record process
                test_g_loss = sum(test_g_losses)/len(test_g_losses)
                test_d_loss = sum(test_d_losses)/len(test_d_losses)
                test_d_real_loss = sum(test_d_real_losses)/len(test_d_real_losses)
                test_d_fake_loss = sum(test_d_fake_losses)/len(test_d_fake_losses)


                train_writer.add_scalar("loss/g_loss", loss_G.item(), step)
                train_writer.add_scalar("loss/d_loss", loss_D.item(), step)
                train_writer.add_scalar("loss/d_real_loss", loss_D_real.item(), step)
                train_writer.add_scalar("loss/d_fake_loss", loss_D_fake.item(), step)
                train_writer.add_scalar("loss/epoch", epoch + 1, step)

                test_writer.add_scalar("loss/g_loss", test_g_loss, step)
                test_writer.add_scalar("loss/d_loss", test_d_loss, step)
                test_writer.add_scalar("loss/d_real_loss", test_d_real_loss, step)
                test_writer.add_scalar("loss/d_fake_loss", test_d_fake_loss, step)

                pred_img = model_G(val_latents)
                grid_img = make_grid(pred_img, nrow=3, padding=0)
                grid_img = grid_img.mul(0.5).add_(0.5)

                train_writer.add_image("train/{}".format(epoch), grid_img, step)

                model_D.train()
                model_G.train()
                
        if (epoch + 1) % opt.n_save_epoch == 0:
            save_dir = os.path.join(opt.checkpoints_dir, opt.exper)

            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            model_g_path = os.path.join(save_dir, "model_G_{}.pth".format(str(epoch + 1)))
            model_d_path = os.path.join(save_dir, "model_D_{}.pth".format(str(epoch + 1)))
            torch.save(model_D.state_dict(), model_d_path)
            torch.save(model_G.state_dict(), model_g_path)

            print("save_model")

    # save model
    save_dir = os.path.join(opt.checkpoints_dir, opt.exper)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    model_g_path = os.path.join(save_dir, "model_G_{}.pth".format(str(opt.epoch)))
    model_d_path = os.path.join(save_dir, "model_D_{}.pth".format(str(opt.epoch)))
    torch.save(model_D.state_dict(), model_d_path)
    torch.save(model_G.state_dict(), model_g_path)
    

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="../dataset/face_crop_img")
    parser.add_argument("--checkpoints_dir", default="../checkpoints")
    parser.add_argument("--exper", default="dcgan")
    parser.add_argument("--tensorboard", default="../tensorboard")
    parser.add_argument("--gpu", action="store_true", default=False)
    parser.add_argument("--epoch", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--test_batch_size", type=int, default=4)
    parser.add_argument("--n_log_step", type=int, default=10)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--n_save_epoch", type=int, default=10)

    parser.add_argument("--latent_size", type=int, default=100)

    # resume
    parser.add_argument("--resume_epoch", type=int, default=0)

    opt = parser.parse_args()
    main(opt)
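The loop above regenerates `fake_img` before the discriminator update. A common alternative (not what this script does) is to reuse the fakes from the generator step and detach them, which saves one Generator forward pass. A toy sketch of one alternating update on 1-D data, with hypothetical tiny linear models standing in for G and D:

```python
import torch
from torch import nn

# Stand-ins for Generator / Discriminator: 8-dim noise -> 4-dim "image" -> score.
G = nn.Linear(8, 4)
D = nn.Linear(4, 1)
opt_G = torch.optim.Adam(G.parameters(), lr=2e-4)
opt_D = torch.optim.Adam(D.parameters(), lr=2e-4)
loss_f = nn.BCEWithLogitsLoss()
ones, zeros = torch.ones(16, 1), torch.zeros(16, 1)

z = torch.randn(16, 8)
real = torch.randn(16, 4)

# --- Generator step: push D to label fakes as real ---
fake = G(z)
loss_G = loss_f(D(fake), ones)
opt_G.zero_grad(); loss_G.backward(); opt_G.step()

# --- Discriminator step: real -> 1, fake -> 0 ---
# fake.detach() blocks gradients from flowing back into G.
loss_D = loss_f(D(real), ones) + loss_f(D(fake.detach()), zeros)
opt_D.zero_grad(); loss_D.backward(); opt_D.step()
print(loss_G.item(), loss_D.item())
```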

generate.py

Here we load a trained model and define the image-generation procedure.

generate.py
import os
import argparse

from networks import Generator

import numpy as np
from tqdm import tqdm
import torch
from torchvision.utils import save_image


def generate(latents, model_G, width, height, save_path):
    assert len(latents) == width * height
    with torch.no_grad():
        pred_images = model_G(latents)
    # map the Tanh output from [-1, 1] back to [0, 1], as train.py does for logging
    pred_images = pred_images.mul(0.5).add(0.5)
    save_image(pred_images, save_path, nrow=width)

def generate_process(opt):

    # ----- Device Setting -----
    device = torch.device("cpu")

    # ----- Output Setting -----

    img_output_dir = opt.output_dir
    if not os.path.exists(img_output_dir):
        os.mkdir(img_output_dir)

    if not os.path.exists(os.path.join(img_output_dir, "images")):
        os.mkdir(os.path.join(img_output_dir, "images"))
    
    if opt.save_latent is True:
        if not os.path.exists(os.path.join(img_output_dir, "latents")):
            os.mkdir(os.path.join(img_output_dir, "latents"))
    
    print("Output :", img_output_dir)

    # ----- Model Loading -----
    print("Use model :", opt.model)
    model_g = Generator()
    model_g.load_state_dict(torch.load(opt.model, map_location="cpu"))
    model_g.to(device)

    model_g.eval()

    if opt.mode == "normal":
        latents = [torch.randn(size=(opt.width * opt.height, opt.latent_size, 1, 1)) for i in range(opt.n_img)]
    
    elif opt.mode == "use_latent":
        assert opt.latent_dir != "None", "latent source directory is not set"
        latent_paths = [os.path.join(opt.latent_dir, name) for name in os.listdir(opt.latent_dir)]
        latents = [torch.from_numpy(np.load(path)) for path in latent_paths]

    elif opt.mode == "inter":
        latent_start = torch.from_numpy(np.load(opt.start_latent))
        latent_end = torch.from_numpy(np.load(opt.end_latent))
        alphas = [float(n / opt.n_img) for n in range(opt.n_img)]
        latents = [alpha * latent_end + (1 - alpha) * latent_start for alpha in alphas]

    print("Generate image num :", len(latents))

    # ----- Generate Step -----
    print("Start Generate Process")
    for index,latent in tqdm(enumerate(latents)):
        img_path = os.path.join(img_output_dir, "images", str(index + 1) + ".png")
        generate(latent, model_g, opt.width, opt.height, img_path)

        if opt.save_latent:
            latent_path = os.path.join(img_output_dir, "latents", str(index + 1) + ".npy")
            np.save(latent_path, latent.numpy())

    print("Finish Generate Process")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="../checkpoints/dcgan_test/model_G_1000.pth")
    parser.add_argument("--output_dir", default="./result")
    parser.add_argument("--save_latent", action="store_true", default=False)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--width", type=int, default=1)
    parser.add_argument("--height", type=int, default=1)
    parser.add_argument("--mode", choice=["normal", "use_latent", "inter"], default="normal")
    parser.add_argument("--n_img", type=int, default=1)

    # normal generation: latent_size must match the trained Generator
    parser.add_argument("--latent_size", type=int, default=100)

    # use_latent mode
    parser.add_argument("--latent_dir", default="None")

    # interpolation between two latents
    parser.add_argument("--generate_inter", action="store_true", default=False)
    parser.add_argument("--start_latent", type=str, default=".")
    parser.add_argument("--end_latent", type=str, default=".")


    opt = parser.parse_args()
    generate_process(opt)
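The "inter" mode blends two latents linearly: alpha = 0 reproduces the start latent and alpha = 1 the end latent. The sketch below includes both endpoints by dividing by n - 1 (the script's own alphas stop just short of the end latent):

```python
import torch

start = torch.randn(100, 1, 1)
end = torch.randn(100, 1, 1)
n = 5
# alpha runs 0, 0.25, 0.5, 0.75, 1 -> five latents from start to end
latents = [(i / (n - 1)) * end + (1 - i / (n - 1)) * start for i in range(n)]
print(torch.allclose(latents[0], start), torch.allclose(latents[-1], end))
# True True
```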
    

Training results

 The dataset was 18,000 anime face images, and training was run on Google Colaboratory. The results are below.

10 epochs

10.png

20 epochs

20.png

30 epochs

30.png

40 epochs

40.png

50 epochs

50.png

100 epochs

100.png

200 epochs

200.png

400 epochs

400.png

600 epochs

600.png

800 epochs

800.png

1000 epochs

1000.png

Closing thoughts

 This time I implemented and trained a DCGAN. As expected, the generated images did not reach very high quality; since the dataset was not annotated or curated, its quality seems to hold the results back. Generating 64×64 images was not especially hard, but at higher resolutions DCGAN's training time would balloon, so I intend to study other models as well.
 
