2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Conditional GAN と GAN の違い

GAN (Generative Adversarial Network) は敵対的生成ネットワークの略称で、この敵対的、という部分を理解すれば Conditional GAN の理解は難しくありません。

敵対的とは

GAN は画像の生成に際し、**生成器( Generator )識別器( Discriminator )**というモジュールを用意し、この二つを競わせることによって本物っぽいでユニークな画像をつくることができます。
よく偽札の製造者と警察官の例で説明されます。

image.png

生成器:

ランダムなノイズを実際のデータ (画像、音声など) に与え似たデータサンプルに変換します。ジェネレーターはランダムで無意味な出力を生成し始めますが、時間の経過とともに、識別器からのフィードバックを通じて学習し、学習データに似たデータサンプルを作成する能力を向上させます

識別器:

学習データからの実際のデータサンプルと、生成器によって生成された偽のサンプルを区別する分類器です。ジェネレータのパフォーマンスに関するフィードバックをジェネレータに提供します。識別器は時間の経過とともに改善され、実際のデータと偽のデータを区別しやすくなります

2014 年に元 Google Brain の Ian Goodfellow さんによって設計されました。

サンプコードをこちらに載せます。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 定数定義
latent_dim = 100  # ノイズの次元
image_size = 28 * 28  # MNISTの画像サイズ
batch_size = 64
epochs = 50
learning_rate = 0.0002

# データ準備
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Generator 定義
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, image_size),
            nn.Tanh()
        )

    def forward(self, z):
        return self.model(z)

# Discriminator 定義
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(image_size, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.model(img)

# モデル初期化
generator = Generator()
discriminator = Discriminator()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)

# オプティマイザと損失関数
g_optimizer = optim.Adam(generator.parameters(), lr=learning_rate)
d_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate)
criterion = nn.BCELoss()

# 学習
for epoch in range(epochs):
    for i, (real_imgs, _) in enumerate(data_loader):
        # 本物の画像
        real_imgs = real_imgs.view(-1, image_size).to(device)
        real_labels = torch.ones((real_imgs.size(0), 1)).to(device)
        fake_labels = torch.zeros((real_imgs.size(0), 1)).to(device)

        # Discriminatorの学習
        z = torch.randn(real_imgs.size(0), latent_dim).to(device)
        fake_imgs = generator(z)

        real_loss = criterion(discriminator(real_imgs), real_labels)
        fake_loss = criterion(discriminator(fake_imgs.detach()), fake_labels)
        d_loss = real_loss + fake_loss

        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # Generatorの学習
        g_loss = criterion(discriminator(fake_imgs), real_labels)

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

    print(f"Epoch [{epoch + 1}/{epochs}]  D Loss: {d_loss.item():.4f}  G Loss: {g_loss.item():.4f}")

# サンプル生成
z = torch.randn(16, latent_dim).to(device)
with torch.no_grad():
    generated_imgs = generator(z).view(-1, 1, 28, 28).cpu()

# 画像を表示
import matplotlib.pyplot as plt

def show_images(images):
    images = (images + 1) / 2  # [-1, 1] -> [0, 1]
    grid = torch.cat([torch.cat([images[i * 4 + j] for j in range(4)], dim=2) for i in range(4)], dim=1)
    plt.imshow(grid.permute(1, 2, 0).squeeze(), cmap="gray")
    plt.axis("off")
    plt.show()

show_images(generated_imgs)

Conditional GAN

Conditional GAN (cGAN) は条件付きの画像生成を行う GAN になります。
つまり、生成器と識別器の両方にこれらのラベルが割り当てられます。したがって、生成器は予想されるラベル出力に類似した出力のみを生成し、識別器は、生成された出力が本物か偽物かをチェックするとともに、画像が特定のラベルと一致するかどうかをチェックします。

image.png

このラベル化のメリットは次の通りです。

  • 収束が速くなる
    偽の画像が従うランダムな分布にも、何らかのパターンが見られるため
  • 生成器の出力を制御できる

参考

「ガン」っていう人と「ギャン」っていう人に分かれますよね!

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?