Conditional GAN と GAN の違い
GAN (Generative Adversarial Network) は敵対的生成ネットワークの略称で、この敵対的、という部分を理解すれば Conditional GAN の理解は難しくありません。
敵対的とは
GAN は画像の生成に際し、**生成器( Generator )と識別器( Discriminator )**というモジュールを用意し、この二つを競わせることによって本物っぽいでユニークな画像をつくることができます。
よく偽札の製造者と警察官の例で説明されます。
2014 年に元 Google Brain の Ian Goodfellow さんによって設計されました。
サンプコードをこちらに載せます。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 定数定義
latent_dim = 100 # ノイズの次元
image_size = 28 * 28 # MNISTの画像サイズ
batch_size = 64
epochs = 50
learning_rate = 0.0002
# データ準備
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Generator 定義
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(latent_dim, 256),
nn.ReLU(),
nn.Linear(256, 512),
nn.ReLU(),
nn.Linear(512, 1024),
nn.ReLU(),
nn.Linear(1024, image_size),
nn.Tanh()
)
def forward(self, z):
return self.model(z)
# Discriminator 定義
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(image_size, 1024),
nn.LeakyReLU(0.2),
nn.Linear(1024, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, img):
return self.model(img)
# モデル初期化
generator = Generator()
discriminator = Discriminator()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)
# オプティマイザと損失関数
g_optimizer = optim.Adam(generator.parameters(), lr=learning_rate)
d_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate)
criterion = nn.BCELoss()
# 学習
for epoch in range(epochs):
for i, (real_imgs, _) in enumerate(data_loader):
# 本物の画像
real_imgs = real_imgs.view(-1, image_size).to(device)
real_labels = torch.ones((real_imgs.size(0), 1)).to(device)
fake_labels = torch.zeros((real_imgs.size(0), 1)).to(device)
# Discriminatorの学習
z = torch.randn(real_imgs.size(0), latent_dim).to(device)
fake_imgs = generator(z)
real_loss = criterion(discriminator(real_imgs), real_labels)
fake_loss = criterion(discriminator(fake_imgs.detach()), fake_labels)
d_loss = real_loss + fake_loss
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()
# Generatorの学習
g_loss = criterion(discriminator(fake_imgs), real_labels)
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()
print(f"Epoch [{epoch + 1}/{epochs}] D Loss: {d_loss.item():.4f} G Loss: {g_loss.item():.4f}")
# サンプル生成
z = torch.randn(16, latent_dim).to(device)
with torch.no_grad():
generated_imgs = generator(z).view(-1, 1, 28, 28).cpu()
# 画像を表示
import matplotlib.pyplot as plt
def show_images(images):
images = (images + 1) / 2 # [-1, 1] -> [0, 1]
grid = torch.cat([torch.cat([images[i * 4 + j] for j in range(4)], dim=2) for i in range(4)], dim=1)
plt.imshow(grid.permute(1, 2, 0).squeeze(), cmap="gray")
plt.axis("off")
plt.show()
show_images(generated_imgs)
Conditional GAN
Conditional GAN (cGAN) は条件付きの画像生成を行う GAN になります。
つまり、生成器と識別器の両方にこれらのラベルが割り当てられます。したがって、生成器は予想されるラベル出力に類似した出力のみを生成し、識別器は、生成された出力が本物か偽物かをチェックするとともに、画像が特定のラベルと一致するかどうかをチェックします。
このラベル化のメリットは次の通りです。
- 収束が速くなる
偽の画像が従うランダムな分布にも、何らかのパターンが見られるため - 生成器の出力を制御できる
参考
「ガン」っていう人と「ギャン」っていう人に分かれますよね!