初めに
はい、
Akira_0809です、
前回、ChatGPTと一緒に作成した画像生成AIを改良します、
text2imageにしていきます、
1を入力すると1の画像が出力されるようにしたい
前回
会話ログ
授業
数字を条件としてGANに入力し画像生成する
「条件付きGAN」(Conditional GAN, cGAN)
を使用する
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# データセットのロード
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
# モデルの定義
class ConditionalGenerator(nn.Module):
def __init__(self):
super(ConditionalGenerator, self).__init__()
self.label_embedding = nn.Embedding(10, 10) # 0-9の数字のエンベディング
self.main = nn.Sequential(
nn.Linear(110, 256), # 100次元のノイズ + 10次元のエンコードされたラベル
nn.ReLU(True),
nn.Linear(256, 512),
nn.ReLU(True),
nn.Linear(512, 1024),
nn.ReLU(True),
nn.Linear(1024, 784),
nn.Tanh()
)
def forward(self, noise, labels):
# ラベルをエンコードし、ノイズと結合
embedded_labels = self.label_embedding(labels)
input = torch.cat((noise, embedded_labels), -1)
return self.main(input).view(-1, 1, 28, 28)
class ConditionalDiscriminator(nn.Module):
def __init__(self):
super(ConditionalDiscriminator, self).__init__()
- self.label_embedding = nn.Embedding(10, 784) # 0-9の数字のエンベディングを画像と同じサイズに
+ self.label_embedding = nn.Embedding(10, 10)
self.main = nn.Sequential(
nn.Linear(784 + 784, 512), # 画像のピクセル情報 + 784次元のエンコードされたラベル
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, img, labels):
# ラベルをエンコードし、画像と結合
embedded_labels = self.label_embedding(labels)
input = torch.cat((img.view(img.size(0), -1), embedded_labels), -1)
return self.main(input)
# モデルのインスタンス作成
generator = ConditionalGenerator()
discriminator = ConditionalDiscriminator()
# 損失関数と最適化関数の定義
criterion = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002)
# トレーニングループ
num_epochs = 100
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(dataloader):
# ラベルを使って条件付きトレーニングを行う
real_labels = torch.ones(images.size(0), 1)
fake_labels = torch.zeros(images.size(0), 1)
# ランダムノイズとランダムラベルを生成
noise = torch.randn(images.size(0), 100)
fake_images = generator(noise, labels)
# ディスクリミネーターのトレーニング
optimizer_d.zero_grad()
outputs = discriminator(images, labels)
d_loss_real = criterion(outputs, real_labels)
outputs = discriminator(fake_images.detach(), labels)
d_loss_fake = criterion(outputs, fake_labels)
d_loss = d_loss_real + d_loss_fake
d_loss.backward()
optimizer_d.step()
# ジェネレーターのトレーニング
optimizer_g.zero_grad()
outputs = discriminator(fake_images, labels)
g_loss = criterion(outputs, real_labels)
g_loss.backward()
optimizer_g.step()
print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item()}, g_loss: {g_loss.item()}')
# 生成された画像の表示
import matplotlib.pyplot as plt
import numpy as np
# 指定した数字を生成する
def generate_and_show(number):
noise = torch.randn(1, 100) # 1つの100次元ノイズ
label = torch.tensor([number]) # 指定した数字のラベル
with torch.no_grad():
fake_image = generator(noise, label).reshape(28, 28)
fake_image = (fake_image + 1) / 2 # [-1, 1] から [0, 1] へ
fake_image = fake_image * 255 # [0, 1] から [0, 255] へ
fake_image = fake_image.numpy().astype("uint8")
plt.imshow(fake_image, cmap='gray')
plt.axis('off')
plt.show()
# 例: 数字5の画像を生成して表示
generate_and_show(5)
生成された画像
ChatGPTによるとランダムなパターンでの生成になってるとのこと
lossも大きかったし、学習がうまくいっていない可能性が高い、
もう一度コードを生成し、画像生成
一旦、仕組みを理解する
なるほど、
入力ベクトルの末尾に条件ベクトルを結合するのか、
他にも入力ノイズに情報を持たせたりも出来るのか、
今回は入力ベクトルの末尾に条件ベクトルを結合するのでやってみる
出来た!
5ぽい
ConditionalDisoriminatorの
self.label_embedding = nn.Embedding(10, 784)
これが
self.label_embedding = nn.Embedding(10, 10)
こう
なぜか画像のサイズに合わせたベクトルに変換してた、、
ConditionalGenerator
ラベル情報を追加
class ConditionalGenerator(nn.Module):
def __init__(self):
super(ConditionalGenerator, self).__init__()
self.label_embedding = nn.Embedding(10, 10) # 0-9の数字のエンベディング
self.main = nn.Sequential(
nn.Linear(110, 256), # 100次元のノイズ + 10次元のエンコードされたラベル
nn.ReLU(True),
nn.Linear(256, 512),
nn.ReLU(True),
nn.Linear(512, 1024),
nn.ReLU(True),
nn.Linear(1024, 784),
nn.Tanh()
)
def forward(self, noise, labels):
# ラベルをエンコードし、ノイズと結合
embedded_labels = self.label_embedding(labels)
input = torch.cat((noise, embedded_labels), -1)
return self.main(input).view(-1, 1, 28, 28)
self.label_embedding = nn.Embedding(10, 10)
ラベルをエンベディングするやつ
nn.Linear(110, 256)
100次元のノイズ + 10次元のエンコードされたラベル
embedded_labels = self.label_embedding(labels)
input = torch.cat((noise, embedded_labels), -1)
ここでラベルのエンコードと結合を行う
ConditionalDiscriminator
ラベル情報を追加
class ConditionalDiscriminator(nn.Module):
def __init__(self):
super(ConditionalDiscriminator, self).__init__()
self.label_embedding = nn.Embedding(10, 10) # ラベルを10次元ベクトルにエンコード
self.main = nn.Sequential(
nn.Linear(784 + 10, 512), # 784次元の画像 + 10次元のラベル
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, img, labels):
embedded_labels = self.label_embedding(labels) # ラベルをエンベディングベクトルに変換
input = torch.cat((img.view(img.size(0), -1), embedded_labels), -1) # 画像とエンベディングを結合
return self.main(input)
ジェネレーターと同じですね、
トレーニングループ
num_epochs = 100
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(dataloader):
# 本物のラベルと偽物のラベルを準備
real_labels = torch.ones(images.size(0), 1)
fake_labels = torch.zeros(images.size(0), 1)
# ノイズとラベルを生成
noise = torch.randn(images.size(0), 100)
fake_images = generator(noise, labels)
# ディスクリミネーターのトレーニング
optimizer_d.zero_grad()
outputs = discriminator(images, labels)
d_loss_real = criterion(outputs, real_labels)
outputs = discriminator(fake_images.detach(), labels)
d_loss_fake = criterion(outputs, fake_labels)
d_loss = d_loss_real + d_loss_fake
d_loss.backward()
optimizer_d.step()
# ジェネレーターのトレーニング
optimizer_g.zero_grad()
outputs = discriminator(fake_images, labels)
g_loss = criterion(outputs, real_labels)
g_loss.backward()
optimizer_g.step()
print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item()}, g_loss: {g_loss.item()}')
ラベルを活用して学習してます、
生成された画像の表示
def generate_and_show_all_digits():
# 0から9のラベルを作成
labels = torch.arange(0, 10).long()
# 各ラベルに対応するノイズを生成
noise = torch.randn(10, 100)
# 画像を生成
with torch.no_grad():
fake_images = generator(noise, labels).reshape(-1, 28, 28)
# 画像を表示
plt.figure(figsize=(10, 1))
for i in range(10):
plt.subplot(1, 10, i+1)
plt.imshow(fake_images[i], cmap='gray')
plt.title(str(i))
plt.axis('off')
plt.show()
# 0から9の数字を生成して表示
generate_and_show_all_digits()
いい感じ!
学習ごとの損失の変化について
学習セクションごとで損失が60だったり1だったりと大きく変わることがあった
ChatGPTによると
- 重みの初期化
- データシャッフル
- 学習率スケジュール
- モデルの不安定性
- ハードウェアの影響
等が挙げられた
割とGANの性質上、運ゲーみたいな所があった、
最後に
2年前は1mmも分からなかったが出来るようになってうれしい!!
ChatGPTのすごさをさらに体感した、
もっと頑張ります!