0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

ConditionalGAN(Pytorch)

Last updated at Posted at 2022-05-20

概要

CondidionalGANは条件付き敵対的生成ネットワークと呼ばれます
通常のGANでは生成されるデータはランダムとなっています。
手書き数字の例でいうと0が生成されるか1が生成されるかなどはランダムに決定され、数字を指定しての作成は難しいです。
そこで、潜在変数と本物画像、偽物画像のそれぞれに現在はなんのデータに関しての学習を行っているかを教えながら学習することで、生成したいデータを指定することができるようになります

ラベリング

現在なんのデータに関しての学習を行っているかを教える方法は以下の通りです。

潜在変数

潜在変数にOne-Hotエンコーディングでラベリングしたベクトルを結合する
条件付ベクトル.png

生成画像、本物画像

One-Hotベクトルを画像サイズに拡大したクラス数分のOne-Hot画像をチャンネル方向に結合する。
One-Hot画像は、すべての要素が1の画像が一枚で残りはすべての要素が0の画像となる

条件付画像.png

全体像

本物画像や偽物画像、そして潜在変数に現在のラベル情報を追加しGANの学習を行うだけです

全体.png

ラベリングの実装

カテゴリカル変数のラベルをOne-Hot形式に変換

def onehot_encode(label, device, n_class=10):
    eye = torch.eye(n_class, device=device)
    return eye[label].view(-1, n_class, 1, 1)

処理は以下の通りになります。

  1. torch.eyeで単位行列の作成
  2. ラベルを指定することでインデックスに従ったone-hotベクトルの取得
  3. viewで形状の調整

onehot_encode.png

潜在変数とOne-Hoeベクトルの連結

def concat_noise_label(noise, label, device, n_class):
    oh_label = onehot_encode(label, device, n_class)
    return torch.cat((noise, oh_label), dim=1)

処理は以下の通りになります。

  1. one-hotベクトルの取得
  2. one-hotベクトルとノイズの結合

concat_noise_label.png

One-Hot画像の作成

def create_onehot_image(image, label, device, n_class):
    B,C,H,W = image.shape
    encode_label = onehot_encode(label, device)
    encode_label = encode_label.expand(B,n_class, H, W)
    return torch.cat((image, encode_label), dim=1)

処理は以下の通りになります。

  1. 入力画像のサイズ取得
  2. one-hotベクトルの取得
  3. expandでOne-Hotベクトルを指定したサイズ分拡大させる(One-Hot画像)
  4. One-Hot画像と入力画像の結合

concat_image_label.png

損失関数

$min_Gmax_DV(D,C) = E_{x~p_{data}(x)}[logD(x|y)] + E_{x~p_{z}(x)}[log(1-D(G(z|y)))]$

損失関数は以上の通りです。

通常のGAN違う点としては以下の通りです

  • D(x|y):識別器(D)の入力がラベル情報の条件付き(y)な画像(x)
  • G(z|y):生成器(G)の入力がラベル情報の条件付き(y)な潜在変数(x)

モデル構造

生成器

通常の生成器のモデルです

class Generator(nn.Module):
  def __init__(self, nz=100, nch_g=64, nch=1):
    super(Generator, self).__init__()
    self.layer1 = nn.Sequential(
        nn.ConvTranspose2d(nz, nch_g*4, 3, 1, 0),
        nn.BatchNorm2d(nch_g*4),
        nn.ReLU()
    )
    self.layer2 = nn.Sequential(
        nn.ConvTranspose2d(nch_g*4, nch_g*2, 3, 2, 0),
        nn.BatchNorm2d(nch_g*2),
        nn.ReLU()
    )
    self.layer3 = nn.Sequential(
        nn.ConvTranspose2d(nch_g*2, nch_g, 4, 2, 1),
        nn.BatchNorm2d(nch_g),
        nn.ReLU()
    )
    self.layer4 = nn.Sequential(
        nn.ConvTranspose2d(nch_g , nch , 4, 2, 1),
        nn.Tanh()
    )

  def forward(self, z):
    x = self.layer1(z)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    return x

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
   ConvTranspose2d-1            [-1, 512, 3, 3]         507,392
       BatchNorm2d-2            [-1, 512, 3, 3]           1,024
              ReLU-3            [-1, 512, 3, 3]               0
   ConvTranspose2d-4            [-1, 256, 7, 7]       1,179,904
       BatchNorm2d-5            [-1, 256, 7, 7]             512
              ReLU-6            [-1, 256, 7, 7]               0
   ConvTranspose2d-7          [-1, 128, 14, 14]         524,416
       BatchNorm2d-8          [-1, 128, 14, 14]             256
              ReLU-9          [-1, 128, 14, 14]               0
  ConvTranspose2d-10            [-1, 1, 28, 28]           2,049
             Tanh-11            [-1, 1, 28, 28]               0
================================================================
Total params: 2,215,553
Trainable params: 2,215,553
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.98
Params size (MB): 8.45
Estimated Total Size (MB): 9.43
----------------------------------------------------------------

識別器

通常の生成器と比較して
「入力の次元サイズ=画像のチャンネル数+ラベルの項目数」
となっている点が異なります

class Discriminator(nn.Module):
  def __init__(self, nch=1, nch_d=64):
    super(Discriminator, self).__init__()
    self.layer1 = nn.Sequential(
        nn.Conv2d(nch, nch_d, 4, 2, 1),
        nn.LeakyReLU(negative_slope=0.2)
    )
    self.layer2 = nn.Sequential(
        nn.Conv2d(nch_d, nch_d*2, 4, 2, 1),
        nn.BatchNorm2d(nch_d*2),
        nn.LeakyReLU(negative_slope=0.2)
    )
    self.layer3 = nn.Sequential(
        nn.Conv2d(nch_d*2, nch_d*4, 3, 2, 0),
        nn.BatchNorm2d(nch_d*4),
        nn.LeakyReLU(negative_slope=0.2)
    )
    self.layer4 = nn.Sequential(
        nn.Conv2d(nch_d*4, 1, 3, 1, 0),
        nn.Sigmoid()
    )

  def forward(self, x):
      x = self.layer1(x)
      x = self.layer2(x)
      x = self.layer3(x)
      x = self.layer4(x)
      return x.squeeze()
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 14, 14]          11,328
         LeakyReLU-2           [-1, 64, 14, 14]               0
            Conv2d-3            [-1, 128, 7, 7]         131,200
       BatchNorm2d-4            [-1, 128, 7, 7]             256
         LeakyReLU-5            [-1, 128, 7, 7]               0
            Conv2d-6            [-1, 256, 3, 3]         295,168
       BatchNorm2d-7            [-1, 256, 3, 3]             512
         LeakyReLU-8            [-1, 256, 3, 3]               0
            Conv2d-9              [-1, 1, 1, 1]           2,305
          Sigmoid-10              [-1, 1, 1, 1]               0
================================================================
Total params: 440,769
Trainable params: 440,769
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.03
Forward/backward pass size (MB): 0.39
Params size (MB): 1.68
Estimated Total Size (MB): 2.10
----------------------------------------------------------------

学習

criterion = nn.BCELoss()

for epoch in range(epoch_num):
    for itr ,data in enumerate(train_dataloader):
        real_image = data[0].to(device) 
        real_label = data[1].to(device)

        # One-Hot画像の作成
        real_image_label = create_onehot_image(real_image, real_label, device, class_num) 

        # 潜在変数とOne-Hotベクトルを結合
        ## 潜在変数の作成
        sample_size = real_image.size(0)  
        noise = torch.randn(sample_size, 100, 1, 1, device=device)
        ## 適当なラベル作成
        fake_label = torch.randint(class_num, (sample_size,), dtype=torch.long,    device=device)
        ## 潜在変数とラベルをOne-Hotエンコーディングを行ったベクトルを結合
        fake_noise_label = concat_noise_label(noise, fake_label, device, class_num)        
        
        real_target = torch.full((sample_size,), 1., device=device)
        fake_target = torch.full((sample_size,), 0., device=device)

        # 識別器の学習
        discriminator.zero_grad()

        ## 本物画像に対する損失計算
        output = discriminator(real_image_label)
        errD_real = criterion(output, real_target)

        ## 偽物画像に対する損失計算
        fake_image = generator(fake_noise_label)
        fake_image_label = concat_image_label(fake_image, fake_label, device)   
        output = discriminator(fake_image_label.detach()) 
        errD_fake = criterion(output, fake_target)  

        ## 識別器全体の損失
        errD = errD_real + errD_fake 
        errD.backward() 

        optimizerD.step()  

        # 生成器の学習
        generator.zero_grad() 

        output = discriminator(fake_image_label) 
        errG = criterion(output, real_target)
        errG.backward()

        optimizerG.step()

補足(One-Hotエンコーディングとは)

分類問題などにおいて使用されるラベルのつけ方です。
正解値は1、それ以外は0になるようにします
例として、以下の5つの感情を予測する分類モデルを作成するとします

ラベル.png

あるデータを抜き出して「悲しみ」が正解の値に対してOne-Hotエンコーディングを行うと以下のようになります

ラベル正解.png

正解値の「悲しみ」は1でそれ以外は0になっているのが分かると思います

今度は、抜き出したデータに対して、分類モデルを適用します

その結果が以下の通りです

ラベル予測.png

正解のラベルのように0と1だけにはならず、0~1の間をとっています。
これは、その値の確信度として考えることができます。
One-Hotエンコーディングでラベリングすることでその値かもしれない確信度をとるようにできます。

参考サイト

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?