61
49

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorchAdvent Calendar 2018

Day 11

5ステップでできるPyTorch - DCGAN

Last updated at Posted at 2018-12-11

概要

PyTorchを使って、以下の5ステップでDCGANを作成します。

  1. データの準備
  2. Generatorの作成
  3. Discriminatorの作成
  4. 訓練関数の作成
  5. DCGANの訓練スタート

当記事は、
DCGANの理論は他の方に任せて、簡単・シンプルなコードで、サクッと動かすことを目的としています。

コード・サンプルデータセットは**GitHub**に載せています。

1. データの準備

1.1 データのダウンロード・前処理

前処理済みのサンプルデータセット(sample_data)をGitHubに用意しているので、以下は読み飛ばしても大丈夫です。

データは、アニメのキャラクターの顔を集めたデータセットAnimeFace Character Datasetを使用します。

しかし、AnimeFace Character Datasetは顔よりも広めに切り取ってあるので、lbpcascade_animefaceを使用して、さらに実際に顔の部分だけを切り出し、画像サイズを64 x 64 に整形しておきます。

face_0_119_15.png   ==>   0.jpg

以上で、計4257枚のデータセットが出来上がりました。

1.2 データの読み込み

GitHubよりsample_dataをダウンロードしているテイで進めます。

.py
import torch
from torch import nn, optim
from torchvision.utils import save_image
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import transforms, datasets
import tqdm
from statistics import mean


# === 1. データの読み込み ===
# datasetrの準備
dataset = datasets.ImageFolder("sample_data/",
    transform=transforms.Compose([
        transforms.ToTensor()
]))

batch_size = 64

# dataloaderの準備
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

2. Generatorの作成

潜在特徴100次元ベクトルzから、3(チャネル) x 64 x 64の画像を生成するモデルを作成します。

Generator ( z )  =>  画像 ( 3x64x64 )

.py
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(

            nn.ConvTranspose2d(100, 256, 4, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(32, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)

3. Discriminatorの作成

3 x 64 x 64 の画像を、1次元のスカラーに変換するモデルを作成します。

Discriminator ( 画像 )  =>  1次元スカラー

.py
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            
            nn.Conv2d(3, 32, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(32, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(256, 1, 4, 1, 0, bias=False),
        )

    def forward(self, x):
        return self.main(x).squeeze()

4. 訓練関数の作成

GANには、GeneratorとDiscriminatorの2つのモデルがあり、それぞれ交互に訓練していきます。

少し長くなりますが、以下のような流れで訓練していきます。(コードを見た方が分かりやすいかも)

(1) 潜在特徴100次元ベクトルzから、Generatorを使用して偽画像を生成する
(2) 偽画像をDiscriminatorで識別し、偽画像を実画像と騙せるようにGeneratorを学習する
(3) 実画像をDiscriminatorで識別する
(4) 偽画像を偽画像、実画像を実画像と識別できるようにDiscriminatorを学習する

.py
model_G = Generator().to("cuda:0")
model_D = Discriminator().to("cuda:0")

params_G = optim.Adam(model_G.parameters(),
    lr=0.0002, betas=(0.5, 0.999))
params_D = optim.Adam(model_D.parameters(),
    lr=0.0002, betas=(0.5, 0.999))

# 潜在特徴100次元ベクトルz
nz = 100

# ロスを計算するときのラベル変数
ones = torch.ones(batch_size).to("cuda:0") # 正例 1
zeros = torch.zeros(batch_size).to("cuda:0") # 負例 0
loss_f = nn.BCEWithLogitsLoss()

# 途中結果の確認用の潜在特徴z
check_z = torch.randn(batch_size, nz, 1, 1).to("cuda:0")


# 訓練関数
def train_dcgan(model_G, model_D, params_G, params_D, data_loader):
    log_loss_G = []
    log_loss_D = []
    for real_img, _ in tqdm.tqdm(data_loader):
        batch_len = len(real_img)


        # == Generatorの訓練 ==
        # 偽画像を生成
        z = torch.randn(batch_len, nz, 1, 1).to("cuda:0")
        fake_img = model_G(z)

        # 偽画像の値を一時的に保存 => 注(1)
        fake_img_tensor = fake_img.detach()

        # 偽画像を実画像(ラベル1)と騙せるようにロスを計算
        out = model_D(fake_img)
        loss_G = loss_f(out, ones[: batch_len])
        log_loss_G.append(loss_G.item())

        # 微分計算・重み更新 => 注(2)
        model_D.zero_grad()
        model_G.zero_grad()
        loss_G.backward()
        params_G.step()


        # == Discriminatorの訓練 ==
        # sample_dataの実画像
        real_img = real_img.to("cuda:0")
        
        # 実画像を実画像(ラベル1)と識別できるようにロスを計算
        real_out = model_D(real_img)
        loss_D_real = loss_f(real_out, ones[: batch_len])

        # 計算省略 => 注(1)
        fake_img = fake_img_tensor

        # 偽画像を偽画像(ラベル0)と識別できるようにロスを計算
        fake_out = model_D(fake_img_tensor)
        loss_D_fake = loss_f(fake_out, zeros[: batch_len])

        # 実画像と偽画像のロスを合計
        loss_D = loss_D_real + loss_D_fake
        log_loss_D.append(loss_D.item())

        # 微分計算・重み更新 => 注(2)
        model_D.zero_grad()
        model_G.zero_grad()
        loss_D.backward()
        params_D.step()

    return mean(log_loss_G), mean(log_loss_D)

注(1):pytorchでは同じTensorに対して2回バックプロパゲーションを計算できないので、一時的に保存しておいたTensorを使用して、無駄な計算(偽画像の生成)を省略する

注(2):計算グラフがGeneratorとDiscriminatorの両方に関わっているので、両モデルの勾配を初期化してから、微分計算・重み更新を行う

5. DCGANの訓練スタート

その後、epoch数だけ訓練関数を実行します。

.py
for epoch in range(300):
    train_dcgan(model_G, model_D, params_G, params_D, data_loader)
    
    # 訓練途中のモデル・生成画像の保存
    if epoch % 10 == 0:
        torch.save(
            model_G.state_dict(),
            "Weight_Generator/G_{:03d}.pth".format(epoch),
            pickle_protocol=4)
        torch.save(
            model_D.state_dict(),
            "Weight_Discriminator/D_{:03d}.pth".format(epoch),
            pickle_protocol=4)

        generated_img = model_G(check_z)
        save_image(generated_img,
                   "Generated_Image/{:03d}.jpg".format(epoch))

訓練後Generatorが作った画像例

290.jpg

データやモデル、ハイパラなど、特にこだわっていない中でこの程度の出来でした。

参考

61
49
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
61
49

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?