1
1

初めに

はい、
Akira_0809です、
前回、ChatGPTと一緒に作成した画像生成AIを改良します、
text2imageにしていきます、
1を入力すると1の画像が出力されるようにしたい

前回

会話ログ

授業

数字を条件としてGANに入力し画像生成する

「条件付きGAN」(Conditional GAN, cGAN)

を使用する

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# データセットのロード
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# モデルの定義
class ConditionalGenerator(nn.Module):
    def __init__(self):
        super(ConditionalGenerator, self).__init__()
        self.label_embedding = nn.Embedding(10, 10)  # 0-9の数字のエンベディング
        self.main = nn.Sequential(
            nn.Linear(110, 256),  # 100次元のノイズ + 10次元のエンコードされたラベル
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        # ラベルをエンコードし、ノイズと結合
        embedded_labels = self.label_embedding(labels)
        input = torch.cat((noise, embedded_labels), -1)
        return self.main(input).view(-1, 1, 28, 28)

class ConditionalDiscriminator(nn.Module):
    def __init__(self):
        super(ConditionalDiscriminator, self).__init__()
-        self.label_embedding = nn.Embedding(10, 784)  # 0-9の数字のエンベディングを画像と同じサイズに
+        self.label_embedding = nn.Embedding(10, 10)
        self.main = nn.Sequential(
            nn.Linear(784 + 784, 512),  # 画像のピクセル情報 + 784次元のエンコードされたラベル
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        # ラベルをエンコードし、画像と結合
        embedded_labels = self.label_embedding(labels)
        input = torch.cat((img.view(img.size(0), -1), embedded_labels), -1)
        return self.main(input)


# モデルのインスタンス作成
generator = ConditionalGenerator()
discriminator = ConditionalDiscriminator()

# 損失関数と最適化関数の定義
criterion = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002)

# トレーニングループ
num_epochs = 100
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(dataloader):
        # ラベルを使って条件付きトレーニングを行う
        real_labels = torch.ones(images.size(0), 1)
        fake_labels = torch.zeros(images.size(0), 1)

        # ランダムノイズとランダムラベルを生成
        noise = torch.randn(images.size(0), 100)
        fake_images = generator(noise, labels)

        # ディスクリミネーターのトレーニング
        optimizer_d.zero_grad()
        outputs = discriminator(images, labels)
        d_loss_real = criterion(outputs, real_labels)
        outputs = discriminator(fake_images.detach(), labels)
        d_loss_fake = criterion(outputs, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_d.step()

        # ジェネレーターのトレーニング
        optimizer_g.zero_grad()
        outputs = discriminator(fake_images, labels)
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        optimizer_g.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item()}, g_loss: {g_loss.item()}')


# 生成された画像の表示
import matplotlib.pyplot as plt
import numpy as np

# 指定した数字を生成する
def generate_and_show(number):
    noise = torch.randn(1, 100)  # 1つの100次元ノイズ
    label = torch.tensor([number])  # 指定した数字のラベル
    with torch.no_grad():
        fake_image = generator(noise, label).reshape(28, 28)
        fake_image = (fake_image + 1) / 2  # [-1, 1] から [0, 1] へ
        fake_image = fake_image * 255  # [0, 1] から [0, 255] へ
        fake_image = fake_image.numpy().astype("uint8")
        plt.imshow(fake_image, cmap='gray')
        plt.axis('off')
        plt.show()

# 例: 数字5の画像を生成して表示
generate_and_show(5)

生成された画像

「5」には見えない、、
ダウンロード (2).png

ChatGPTによるとランダムなパターンでの生成になってるとのこと
lossも大きかったし、学習がうまくいっていない可能性が高い、

もう一度コードを生成し、画像生成

うーん
epoch増やしてみるか、
ダウンロード (3).png

ダメぽい
ダウンロード (4).png

一旦、仕組みを理解する

なるほど、
入力ベクトルの末尾に条件ベクトルを結合するのか、
他にも入力ノイズに情報を持たせたりも出来るのか、

今回は入力ベクトルの末尾に条件ベクトルを結合するのでやってみる

出来た!
5ぽい
ダウンロード (5) (1).png
ConditionalDisoriminatorの
self.label_embedding = nn.Embedding(10, 784)
これが
self.label_embedding = nn.Embedding(10, 10)
こう

なぜか画像のサイズに合わせたベクトルに変換してた、、

ConditionalGenerator

ラベル情報を追加

class ConditionalGenerator(nn.Module):
    def __init__(self):
        super(ConditionalGenerator, self).__init__()
        self.label_embedding = nn.Embedding(10, 10)  # 0-9の数字のエンベディング
        self.main = nn.Sequential(
            nn.Linear(110, 256),  # 100次元のノイズ + 10次元のエンコードされたラベル
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        # ラベルをエンコードし、ノイズと結合
        embedded_labels = self.label_embedding(labels)
        input = torch.cat((noise, embedded_labels), -1)
        return self.main(input).view(-1, 1, 28, 28)

self.label_embedding = nn.Embedding(10, 10)
ラベルをエンベディングするやつ

nn.Linear(110, 256)
100次元のノイズ + 10次元のエンコードされたラベル

embedded_labels = self.label_embedding(labels)
input = torch.cat((noise, embedded_labels), -1)
ここでラベルのエンコードと結合を行う

ConditionalDiscriminator

ラベル情報を追加

class ConditionalDiscriminator(nn.Module):
    def __init__(self):
        super(ConditionalDiscriminator, self).__init__()
        self.label_embedding = nn.Embedding(10, 10)  # ラベルを10次元ベクトルにエンコード
        self.main = nn.Sequential(
            nn.Linear(784 + 10, 512),  # 784次元の画像 + 10次元のラベル
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        embedded_labels = self.label_embedding(labels)  # ラベルをエンベディングベクトルに変換
        input = torch.cat((img.view(img.size(0), -1), embedded_labels), -1)  # 画像とエンベディングを結合
        return self.main(input)

ジェネレーターと同じですね、

トレーニングループ

num_epochs = 100
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(dataloader):
        # 本物のラベルと偽物のラベルを準備
        real_labels = torch.ones(images.size(0), 1)
        fake_labels = torch.zeros(images.size(0), 1)

        # ノイズとラベルを生成
        noise = torch.randn(images.size(0), 100)
        fake_images = generator(noise, labels)

        # ディスクリミネーターのトレーニング
        optimizer_d.zero_grad()
        outputs = discriminator(images, labels)
        d_loss_real = criterion(outputs, real_labels)
        outputs = discriminator(fake_images.detach(), labels)
        d_loss_fake = criterion(outputs, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_d.step()

        # ジェネレーターのトレーニング
        optimizer_g.zero_grad()
        outputs = discriminator(fake_images, labels)
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        optimizer_g.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item()}, g_loss: {g_loss.item()}')

ラベルを活用して学習してます、

生成された画像の表示

def generate_and_show_all_digits():
    # 0から9のラベルを作成
    labels = torch.arange(0, 10).long()
    
    # 各ラベルに対応するノイズを生成
    noise = torch.randn(10, 100)
    
    # 画像を生成
    with torch.no_grad():
        fake_images = generator(noise, labels).reshape(-1, 28, 28)
    
    # 画像を表示
    plt.figure(figsize=(10, 1))
    for i in range(10):
        plt.subplot(1, 10, i+1)
        plt.imshow(fake_images[i], cmap='gray')
        plt.title(str(i))
        plt.axis('off')
    plt.show()

# 0から9の数字を生成して表示
generate_and_show_all_digits()

ダウンロード.png

いい感じ!

学習ごとの損失の変化について

学習セクションごとで損失が60だったり1だったりと大きく変わることがあった
ChatGPTによると

  • 重みの初期化
  • データシャッフル
  • 学習率スケジュール
  • モデルの不安定性
  • ハードウェアの影響

等が挙げられた

割とGANの性質上、運ゲーみたいな所があった、

最後に

2年前は1mmも分からなかったが出来るようになってうれしい!!
ChatGPTのすごさをさらに体感した、
もっと頑張ります!

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1