LoginSignup
0
1

今更ながら生成モデルGANについての勉強です。
論文と構造が異なる部分がありますので、ご了承ください。
また、コード内のコメントはchatgptによるものです。

以下の論文について実装を行っていきます。

タイトル: UNSUPERVISED REPRESENTATION LEARNING WITH DEEP CONVOLUTIONAL GENERATIVE ADVERSARIAL NETWORKS
https://arxiv.org/pdf/1511.06434.pdf

しばらくうまく学習できていなかったのですが、最終的に以下の本の構造を参考にしました。

DCGAN

通常の畳み込みを使ったGANとの違いは、

  • プーリング層を使わず識別器ではConvolution層のstrideで、生成器ではfractional-strided-convolution(逆畳み込み)で画像のサイズを変換する
  • 識別器・生成器共にBatchNormalizationを使用
  • 全結合層を除外
  • 生成器における活性化関数は出力のTanhを除いてすべてReLUを使用する
  • 識別器ではすべての活性化関数をLeakyReLUとする(確率を出力する際はsigmoidを使用)

LeakyReLUの$\alpha$は0.2とします。

生成器の構造を図に示す。

image.png

最適化手法はAdamを使用し、学習率は0.0002、$\beta_1$は0.5とする。

識別器の損失関数

GANの損失関数は次の式となる。
$$
\min_G\max_D L(D,G)=E_{x\sim p_{data}(x)}[\log D(x)]+E_{z\sim p_{z}(z)}[\log(1- D(G(z)))]
$$
損失関数$L(D,G)$は識別器が最大化、生成器が最小化する。
識別器と生成器は交互に学習するため、$L(D)$と$L(G)$の2つに分けて考える。
識別器の損失関数$L(d)$は符号を逆にして、最小化問題として考える。
$$
L(D,G)=-E_{x\sim p_{data}(x)}[\log D(x)]-E_{z\sim p_{z}(z)}[\log(1- D(G(z)))]
$$
本物画像の入力時は識別信号$D(x)$を出力し、生成画像の入力時は識別信号$D(G(z))$を出力する。
第1項は$D(x)$が1のときに損失が0となり、第2項は$D(G(z))$が0のときに損失が0となる。
識別器の損失関数はバイナリークロスエントロピーで定式化できる。
$$
J^D=-y\log D(x)-(1-y)\log (1-D(G(z)))
$$

生成器の損失関数

$L(D,G)$の最小化を行う。
$$
\min_G\max_D L(D,G)=E_{x\sim p_{data}(x)}[\log D(x)]+E_{z\sim p_{z}(z)}[\log(1- D(G(z)))]
$$
第1項はノイズに対して定数となるので、第2項だけ使用する。
$$
L(G)=E_{z\sim p_{z}(z)}[\log(1- D(G(z)))]
$$
識別器をだますことができれば、$D(G(z))=1$で損失関数はマイナス無限大となり最小化する。
この関数であると、0付近で勾配が小さく、学習が進まないという問題があるので次のように書き替える。
$$
L(G)=-E_{z\sim p_{z}(z)}[\log(D(G(z)))]
$$
生成器の損失関数もバイナリークロスエントロピーで定式化できる。
$$
J^G=-J^D=y\log D(x)+(1-y)\log (1-D(G(z)))
$$
生成器では$y=0$の生成画像クラスだけを使用する
$$
J^G=\log (1-D(G(z)))
$$
さらに、$L(G)$と同様に書き換える。
$$
J^G=-\log D(G(z))
$$
この式はバイナリークロスエントロピーの$J^G=-y\log D(G(z))$の$y=1$の式と同じで、正解ラベルは$y=1$と考える。

実装(keras)

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Conv2D, Conv2DTranspose, BatchNormalization, Activation, ReLU, Reshape,LeakyReLU, Dropout, Flatten, Input
from tensorflow.keras.optimizers import Adam
from keras.datasets import mnist

import numpy as np
import cv2

import matplotlib.pyplot as plt

データはmnistを使用します。
データの読み込みを行っておきます。

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
x_train=x_train.astype('float32')
x_test.astype('float32')
x_train = x_train/127.5 - 1
x_test = x_test/127.5 - 1

可視化用の関数を定義します。

def generate_images(i, type='keras'):
    """指定されたタイプに応じて画像を生成して表示する関数。

    Args:
        i (int): インデックス。
        type (str, optional): 生成器のタイプ。'keras'または'numpy'を指定。デフォルトは'keras'"""
    n_rows = 16
    n_cols = 16

    if type == 'numpy':
        # numpyの場合、ランダムなノイズを生成して順伝播で画像を生成
        noise = np.random.normal(0, 1, (n_rows * n_cols, n_noise))
        g_imgs = forward_propagation(noise, gen_layers)
    elif type == 'keras':
        # kerasの場合、ランダムなノイズを生成し、Generatorを通じて画像を生成
        noise = np.random.normal(0, 1, (n_rows * n_cols, n_noise))
        g_imgs = np.array(gan.gen(noise))
    else:
        # torchの場合、ランダムなノイズを生成し、Generatorを通じて画像を生成
        noise = torch.normal(0, 1, (n_rows * n_cols, n_noise))
        g_imgs = generator(noise)
        g_imgs = g_imgs.detach().numpy()

    g_imgs = g_imgs / 2 + 0.5

    img_size_spaced = img_size + 2

    matrix_image = np.zeros((img_size_spaced * n_rows,
                             img_size_spaced * n_cols))

    for r in range(n_rows):
        for c in range(n_cols):
            g_img = g_imgs[r * n_cols + c].reshape(img_size, img_size)
            top = r * img_size_spaced
            left = c * img_size_spaced
            matrix_image[top:top + img_size, left:left + img_size] = g_img

    # 生成された画像を表示
    plt.figure(figsize=(8, 8))
    plt.imshow(matrix_image.tolist(), cmap='Greys_r')
    plt.tick_params(labelbottom=False, labelleft=False, bottom=False, left=False)
    plt.show()
class gen_conv(Model):
    """Generator Modelの畳み込みレイヤーを定義するクラス。

    Args:
        out_channels (int): 出力チャンネル数。
        kernel_size (int): カーネルサイズ。
        stride (int): ストライドサイズ。
        padding (str): パディングの方法。'same'または'valid'を指定。
        bias (bool): バイアス項を使用するかどうか。
        final (bool): 最終層かどうかを示すフラグ。
    """
    def __init__(self, out_channels, kernel_size=3, stride=2, padding='same', bias=False, final=False):
        super().__init__()
        if final:
            # 最終層の場合、Conv2Dレイヤーを使用し、活性化関数にtanhを適用
            self.conv = Conv2D(out_channels,
                               kernel_size=kernel_size,
                               strides=stride,
                               padding=padding,
                               use_bias=bias)
            self.bn = None
            self.act = Activation("tanh")
        else:
            # 最終層でない場合、Conv2DTransposeレイヤーを使用し、BatchNormalizationとReLUを適用
            self.conv = Conv2DTranspose(out_channels,
                                        kernel_size=kernel_size,
                                        strides=stride,
                                        padding=padding,
                                        use_bias=bias)
            self.bn = BatchNormalization(momentum=0.8)
            self.act = ReLU()

    def call(self, x):
        """モデルのフォワードパスを定義する関数。

        Args:
            x (tensor): 入力テンソル。

        Returns:
            tensor: 出力テンソル。
        """
        x = self.conv(x)
        if self.bn:
            x = self.bn(x)
        x = self.act(x)
        return x

class Generator(Model):
    """Generatorモデルを定義するクラス。

    Attributes:
        linear (Dense): 128*7*7次元の全結合層。
        gen1 (gen_conv): 第1層のgen_convオブジェクト。
        gen2 (gen_conv): 第2層のgen_convオブジェクト。
        gen3 (gen_conv): 最終層のgen_convオブジェクト。
    """
    def __init__(self):
        super().__init__()
        self.linear = Dense(128*7*7, activation="relu")

        self.gen1 = gen_conv(out_channels=128, kernel_size=3, stride=2, padding='same')
        self.gen2 = gen_conv(out_channels=64, kernel_size=3, stride=2, padding='same')
        self.gen3 = gen_conv(out_channels=1, kernel_size=3, stride=1, padding='same', final=True)

    def call(self, x):
        """モデルのフォワードパスを定義する関数。

        Args:
            x (tensor): 入力テンソル。

        Returns:
            tensor: 出力テンソル。
        """
        x = self.linear(x)
        x = Reshape(target_shape=(7, 7, 128))(x)
        x = self.gen1(x)
        x = self.gen2(x)
        x = self.gen3(x)
        return x
gen = Generator()
gen.build((None,100))  # build with input shape.
dummy_input = Input(shape=(100))  # declare without batch demension.
model_summary = Model(inputs=[dummy_input], outputs=gen.call(dummy_input))
model_summary.summary()

image.png

class disc_conv(Model):
    """Discriminator Modelの畳み込みレイヤーを定義するクラス。

    Args:
        out_channels (int): 出力チャンネル数。
        kernel_size (int, optional): カーネルサイズ。デフォルトは5。
        stride (int, optional): ストライドサイズ。デフォルトは2。
        padding (str, optional): パディングの方法。'same'または'valid'を指定。デフォルトは'same'。
        bias (bool, optional): バイアス項を使用するかどうか。デフォルトはFalse。
    """
    def __init__(self, out_channels, kernel_size=5, stride=2, padding='same', bias=False):
        super().__init__()
        self.act = LeakyReLU(alpha=0.2)
        self.drop = Dropout(0.25)
        self.conv = Conv2D(out_channels,
                           kernel_size=kernel_size,
                           strides=stride,
                           padding=padding,
                           use_bias=bias)

    def call(self, x):
        """モデルのフォワードパスを定義する関数。

        Args:
            x (tf.Tensor): 入力テンソル。

        Returns:
            tf.Tensor: 出力テンソル。
        """
        x = self.conv(x)
        x = self.drop(x)
        x = self.act(x)
        return x

class Discriminator(Model):
    """Discriminatorモデルを定義するクラス。

    Attributes:
        disc1 (disc_conv): 第1層のdisc_convオブジェクト。
        disc2 (disc_conv): 第2層のdisc_convオブジェクト。
        disc3 (disc_conv): 第3層のdisc_convオブジェクト。
        disc4 (disc_conv): 第4層のdisc_convオブジェクト。
        flat (Flatten): Flattenレイヤー。
        out (Dense): 1次元の全結合層(出力層)。
    """
    def __init__(self):
        super().__init__()
        self.disc1 = disc_conv(out_channels=32, kernel_size=3, stride=2, padding='same')
        self.disc2 = disc_conv(out_channels=64, kernel_size=3, stride=2, padding='same')
        self.disc3 = disc_conv(out_channels=128, kernel_size=3, stride=2, padding='same')
        self.disc4 = disc_conv(out_channels=256, kernel_size=3, stride=1, padding='same')

        self.flat = Flatten()
        self.out = Dense(1, activation='sigmoid')

    def call(self, x):
        """モデルのフォワードパスを定義する関数。

        Args:
            x (tf.Tensor): 入力テンソル。

        Returns:
            tf.Tensor: 出力テンソル。
        """
        x = self.disc1(x)
        x = self.disc2(x)
        x = self.disc3(x)
        x = self.disc4(x)
        x = self.flat(x)
        x = self.out(x)
        return x
disc = Discriminator()
disc.build((None,28,28,1))  # build with input shape.
dummy_input = Input(shape=(28,28,1))  # declare without batch demension.
model_summary = Model(inputs=[dummy_input], outputs=disc.call(dummy_input))
model_summary.summary()

image.png

class GAN(Model):
    """GAN(Generative Adversarial Network)モデルを定義するクラス。

    Args:
        Discriminator (Model): Discriminatorモデルのインスタンス。
        Generator (Model): Generatorモデルのインスタンス。
    """
    def __init__(self, Discriminator, Generator):
        super().__init__()
        self.gen = Generator
        self.disc = Discriminator

    def call(self, x):
        """モデルのフォワードパスを定義する関数。

        Args:
            x (tf.Tensor): 入力テンソル。

        Returns:
            tf.Tensor: 出力テンソル。
        """
        x = self.gen(x)
        x = self.disc(x)
        return x
gan = GAN(disc, gen)
gan.build((None,100))  # build with input shape.
dummy_input = Input(shape=(100))  # declare without batch demension.
model_summary = Model(inputs=[dummy_input], outputs=gan.call(dummy_input))
model_summary.summary()

image.png

# Discriminatorの設定
optim = Adam(lr=2e-5, beta_1=0.2)
discriminator = Discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=optim, metrics=['accuracy'])
# discriminatorのcompile後にFalseとする(compileしたものには影響がない)
discriminator.trainable = False

# Generatorの設定
optim = Adam(lr=2e-5, beta_1=0.2)
generator = Generator()

# GANの定義
gan = GAN(discriminator, generator)
gan.compile(loss='binary_crossentropy', optimizer=optim, metrics=['accuracy'])
# ハイパーパラメータ設定
n_noise = 100
n_learn = 3001
interval = 100
batch_size = 128
img_size = 28

# DiscriminatorとGeneratorへの入力用のラベル
real = np.ones((batch_size, 1)).astype(int)
fake = np.zeros((batch_size, 1)).astype(int)

# GANのエラーと精度を記録するための配列を初期化
gan_error_record = np.zeros((n_learn, 2))
gan_acc_record = np.zeros((n_learn, 2))

for i in range(n_learn):

    #----------------------
    # Discriminatorの訓練
    #----------------------

    # バッチサイズ分のランダムなデータを取得
    rand_ids = np.random.randint(len(x_train), size=batch_size)
    imgs_real = x_train[rand_ids]

    # ノイズを生成してGeneratorを通じて偽の画像を生成
    noise = np.randn(0, 1, (batch_size, n_noise))
    imgs_fake = gan.gen(noise)

    # Discriminatorの訓練:実際のデータを本物(real)として、偽のデータを偽物(fake)として訓練
    d_loss_real = discriminator.train_on_batch(imgs_real, real)
    d_loss_fake = discriminator.train_on_batch(imgs_fake, fake)
    d_loss, accuracy = np.add(d_loss_real, d_loss_fake) * 0.5

    # エラーと精度を記録
    gan_error_record[i][0] = d_loss
    gan_acc_record[i][0] = accuracy

    #----------------------
    # Generatorの訓練
    #----------------------

    # ノイズを生成してGeneratorを通じて偽の画像を生成
    noise = np.random.normal(0, 1, (batch_size, n_noise))
    gen_imgs = gan.gen(noise)

    # Generatorの訓練:生成された偽の画像を本物(real)として訓練
    g_loss, accuracy = gan.train_on_batch(noise, real)

    # エラーと精度を記録
    gan_error_record[i][1] = g_loss
    gan_acc_record[i][1] = accuracy

    # インターバルごとに進捗を表示し、生成された画像を表示
    if i % interval == 0:
        print("n_learn:", i)
        print("Error_fake:", gan_error_record[i][0],
              "Acc_fake:", gan_acc_record[i][0])
        print("Error_real:", gan_error_record[i][1],
              "Acc_real:", gan_acc_record[i][1])
        generate_images(i, type='keras')

image.png

image.png

実装(pytorch)

import torch
import torch.nn as nn
import torch.optim as optimizers
from torchsummary import summary

import numpy as np
class gen_conv(nn.Module):
    """Generator用のConvulutionを定義するクラス"""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False, final=False):
        super().__init__()
        if final:
            # 最終層の場合はConv2dとTanhを使用
            self.conv = nn.Conv2d(in_channels,
                                  out_channels,
                                  kernel_size=kernel_size,
                                  stride=stride,
                                  padding=padding,
                                  bias=bias)
            self.bn = False  # 最終層ではBatchNormを使わない
            self.act = nn.Tanh()  # 最終層ではTanhを使用

        else:
            # 最終層でない場合はConvTranspose2dとReLUを使用
            self.conv = nn.ConvTranspose2d(in_channels,
                                           out_channels,
                                           kernel_size=kernel_size,
                                           stride=stride,
                                           padding=padding,
                                           bias=bias)
            self.bn = nn.BatchNorm2d(out_channels, momentum=0.8)  # BatchNormを使う
            self.act = nn.ReLU()  # ReLUを使用

    def forward(self, x):
        x = self.conv(x)
        if self.bn:
            x = self.bn(x)
        x = self.act(x)
        return x

class Generator(nn.Module):
    """Generatorモデルを定義するクラス"""

    def __init__(self, z_dim=100, image_dim=1, img_size=28, hidden_dim=128):
        super().__init__()
        self.linear = nn.Linear(in_features=z_dim, out_features=128*7*7)  # 全結合層
        self.act = nn.ReLU()  # ReLUを使用

        # 3つのgen_conv層を定義
        self.gen1 = gen_conv(in_channels=hidden_dim, out_channels=128, kernel_size=4, stride=2, padding=1)
        self.gen2 = gen_conv(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.gen3 = gen_conv(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, final=True)

    def forward(self, x):
        x = self.linear(x)
        x = self.act(x)
        x = x.view(-1, 128, 7, 7)  # reshape
        x = self.gen1(x)
        x = self.gen2(x)
        x = self.gen3(x)
        return x

class disc_conv(nn.Module):
    """Discriminator用のConvulutionを定義するクラス

    Args:
        in_channels (int): 入力チャネル数
        out_channels (int): 出力チャネル数
        kernel_size (int, optional): カーネルサイズ (デフォルト値: 5)
        stride (int, optional): ストライド (デフォルト値: 2)
        padding (int, optional): パディング (デフォルト値: 1)
        bias (bool, optional): バイアスを使用するかどうか (デフォルト値: False)
    """

    def __init__(self, in_channels, out_channels, kernel_size=5, stride=2, padding=1, bias=False):
        super().__init__()
        self.act = nn.LeakyReLU(negative_slope=0.2)  # LeakyReLUを使用
        self.drop = nn.Dropout(0.25)  # Dropoutを使用
        self.conv = nn.Conv2d(in_channels,
                              out_channels,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              bias=bias)

    def forward(self, x):
        """順伝播処理

        Args:
            x (torch.Tensor): 入力テンソル

        Returns:
            torch.Tensor: 出力テンソル
        """
        x = self.conv(x)
        x = self.drop(x)
        x = self.act(x)
        return x

class Discriminator(nn.Module):
    """Discriminatorモデルを定義するクラス

    Args:
        img_size (int, optional): 入力画像のサイズ (デフォルト値: 28)
        hidden_dim (int, optional): 隠れ層の次元数 (デフォルト値: 128)
    """

    def __init__(self, img_size=28, hidden_dim=128):
        super().__init__()
        self.disc1 = disc_conv(in_channels=1, out_channels=32, kernel_size=4, stride=2, padding=1)
        self.disc2 = disc_conv(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.disc3 = disc_conv(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
        self.disc4 = disc_conv(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)

        self.out = nn.Linear(in_features=256*4*4, out_features=1)  # 全結合層
        self.act = torch.sigmoid  # sigmoid関数を使用

    def forward(self, x):
        """順伝播処理

        Args:
            x (torch.Tensor): 入力テンソル

        Returns:
            torch.Tensor: 出力テンソル
        """
        x = self.disc1(x)
        x = self.disc2(x)
        x = self.disc3(x)
        x = self.disc4(x)
        x = nn.Flatten()(x)  # テンソルを1次元に変換
        x = self.out(x)
        x = self.act(x)
        return x.squeeze()  # テンソルのサイズを調整
model = Generator()
summary(model, (1,100))

image.png

model = Discriminator()
summary(model, (1,28,28))

image.png

# DiscriminatorとGeneratorのインスタンスを作成
discriminator = Discriminator()
generator = Generator()

# ウェイトの初期化関数
def weights_init(m):
    """モデルのウェイトを初期化する関数

    Args:
        m (nn.Module): ニューラルネットワークの層
    """
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.xavier_uniform_(m.weight, 1.0)  # 重みをXavierで初期化
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal(m.weight, 0.0, 1.0)  # バッチ正規化層の重みを正規分布で初期化
        torch.nn.init.constant_(m.bias, 0)  # バッチ正規化層のバイアスを0で初期化

# GeneratorとDiscriminatorのウェイトを初期化
generator = generator.apply(weights_init)
discriminator = discriminator.apply(weights_init)

# 損失関数を定義
criterion = nn.BCELoss()

# オプティマイザを定義
optimizerD = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.2, 0.999), eps=1e-7)
optimizerG = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.2, 0.999), eps=1e-7)

# 学習に関するパラメータ
n_noise = 100
n_learn = 3001
interval = 100
batch_size = 128
img_size = 28

# 正解ラベルと偽ラベルを定義
real = torch.full((batch_size,), 1, dtype=torch.float32)
fake = torch.full((batch_size,), 0, dtype=torch.float32)

# GANの学習過程の記録用の配列を初期化
gan_error_record = np.zeros((n_learn, 2))
gan_acc_record = np.zeros((n_learn, 2))

for i in range(n_learn):
    #----------------------
    # Discriminatorの訓練
    #----------------------
    # 正解画像で損失の計算(正解ラベル: 1)
    rand_ids = np.random.randint(len(x_train), size=batch_size)
    imgs_real = torch.tensor(x_train[rand_ids], dtype=torch.float32)
    imgs_real = imgs_real.transpose(3, 1).transpose(3, 2)

    discriminator.zero_grad()
    output = discriminator(imgs_real)
    d_loss_real = criterion(output, real)

    d_x = output.mean().item()

    # 生成画像で損失の計算(正解ラベル: 0)
    noise = torch.randn(batch_size, n_noise)
    fake_image = generator(noise)

    output = discriminator(fake_image)
    d_loss_fake = criterion(output, fake)

    d_g_z1 = output.mean().item()

    # 2つの損失を足して訓練
    d_loss = (d_loss_real + d_loss_fake) * 0.5
    d_loss.backward()
    optimizerD.step()

    #----------------------
    # Generatorの訓練
    #----------------------
    # 生成画像で損失の計算(正解ラベル: 1)
    noise = torch.randn(batch_size, n_noise)
    fake_image = generator(noise)

    generator.zero_grad()
    output = discriminator(fake_image)
    g_loss = criterion(output, real)

    # 訓練
    g_loss.backward()
    d_g_z2 = output.mean().item()
    optimizerG.step()

    if i % interval == 0:
        print("n_learn:", i)
        print(d_g_z1, d_g_z2)
        generate_images(i, type='torch')

image.png

image.png

以上となります。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1