1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

VAEはなぜAEのアーキテクチャを採用するのか?CNNでも可能?(実装)

Last updated at Posted at 2025-02-26

1. 要約

本記事ではVAEをCNNで実装しました。

また、以下の点にも焦点を当てています

  • VAEがAE(Autoencoder)のアーキテクチャを採用する理由
  • CNNを用いるメリット

VAEの基礎は過去の記事で解説しています:

コード:

2. CNNを用いたVAEの構築

線形結合層(Full connection layer)を用いたVAEの特徴

まず、データをFCに入力可能な状態にするために、データをフラットな状態に変換します。
その後、NNを用いて潜在空間に任意の次元数まで圧縮します。
そこで得られた潜在変数を用いて、再度情報を元に戻し、再構成誤差の低下と潜在空間の確率分布が設定した分布と近くなるように学習します。

CNNをVAEに用いるメリット

CNNは空間の特徴量を取得できると考えられています。
学習させたいものが画像である場合、空間情報は大切であると考えられるので、VAEの中にCNNを組み込む動機が与えられます。

加えて、学習するパラメータの数も少なくすることが可能です(=計算コストの削減)。

CNNを用いたVAEの構造

Encoder(初期案)

スクリーンショット 2025-02-22 16.44.31.png

Decoder

スクリーンショット 2025-02-22 16.44.22.png

注意点

  • Encoder:潜在確率分布の平均や分散を求める時にもCNNを利用している

3. 実装

コード全体はこちら:

Encoder
class Encoder(nn.Module):
    def __init__(self, latent_dim, times, batch_size):
        super(Encoder, self).__init__()
        self.kernel_num = 32 * times
        self.conv1 = nn.Conv2d(1, self.kernel_num, kernel_size=3, stride=1, padding=1)  # Feature extraction
        self.conv2 = nn.Conv2d(self.kernel_num, self.kernel_num, kernel_size=3, stride=1, padding=1)  # Maintain 32 channels
        self.conv3 = nn.Conv2d(self.kernel_num, self.kernel_num, kernel_size=3, stride=1, padding=1)  # Maintain 32 channels
        self.downsample1 = nn.Conv2d(self.kernel_num, self.kernel_num, kernel_size=3, stride=2, padding=1)  # Downsampling
        self.downsample2 = nn.Conv2d(self.kernel_num, self.kernel_num, kernel_size=3, stride=2, padding=1)  # Further Downsampling
        self.downsample3 = nn.Conv2d(self.kernel_num, self.kernel_num, kernel_size=3, stride=2, padding=1)  # Further Downsampling
        self.conv_for_mu = nn.Conv2d(self.kernel_num, self.kernel_num//2, kernel_size=3, stride=1, padding=1)
        self.conv_for_logvar = nn.Conv2d(self.kernel_num, self.kernel_num//2, kernel_size=3, stride=1, padding=1)
        self.activation = nn.SiLU()
        self.batch_size = batch_size

    def forward(self, x):
        x = self.activation(self.conv1(x))
        x = self.activation(self.downsample1(x))
        x = self.activation(self.conv2(x))
        x = self.activation(self.downsample2(x))
        x = self.activation(self.conv3(x))
        x = self.activation(self.downsample3(x))
        mu = self.activation(self.conv_for_mu(x))
        logvar = self.activation(self.conv_for_logvar(x))
        mu = torch.flatten(mu, start_dim=1)
        logvar = torch.flatten(logvar, start_dim=1)
        sigma = torch.exp(0.5 * logvar)
        return mu, sigma

ここではダウンサンプリングを3回行っている。

最初の画像のサイズが28 * 28 -> 14 * 14 -> 7 * 7 -> 4 * 4となっている。

4 * 4に画像を圧縮後、平均を抽出するようのCNNに通し(B=32, C=16, W=4, H=4)とし、これを平坦化(=flatten)することで平均の特徴量を抽出させようとしている。

decoder
class Decoder(nn.Module):
    def __init__(self, latent_dim, times, batch_size):
        super(Decoder, self).__init__()
        self.kernel_num = 32 * times
        self.conv1 = nn.Conv2d(self.kernel_num//2, self.kernel_num, kernel_size=3, stride=1, padding=1)  # Feature extraction
        self.conv2 = nn.Conv2d(self.kernel_num, self.kernel_num//2, kernel_size=3, stride=1, padding=1)  # Maintain 32 channels
        self.conv3 = nn.Conv2d(self.kernel_num//2, self.kernel_num//4, kernel_size=3, stride=1, padding=1)  # Maintain 32 channels
        self.upsample1 = nn.ConvTranspose2d(self.kernel_num, self.kernel_num, kernel_size=3, stride=2, padding=1,  output_padding=0) # to adjust expected size of outputs
        self.upsample2 = nn.ConvTranspose2d(self.kernel_num//2, self.kernel_num//2, kernel_size=3, stride=2, padding=1,  output_padding=1)
        self.upsample3 = nn.ConvTranspose2d(self.kernel_num//4, self.kernel_num//4, kernel_size=3, stride=2, padding=1,  output_padding=1)
        self.deconv_final = nn.ConvTranspose2d(self.kernel_num//4, 1, kernel_size=3, stride=1, padding=1)
        self.activation = nn.SiLU()
        self.batch_size = batch_size

    def forward(self, z):
        # print(z.shape) # (batch_size, 256)
        x = z.view(self.batch_size, self.kernel_num//2, 4, 4)
        x = self.activation(self.conv1(x))
        x = self.activation(self.upsample1(x))
        x = self.activation(self.conv2(x))
        x = self.activation(self.upsample2(x))
        x = self.activation(self.conv3(x))
        x = self.activation(self.upsample3(x))
        x = self.deconv_final(x)
        x_hat = F.sigmoid(x)
        return x_hat

アップサンプリングを3回組み合わせて、画像を28 * 28に次元拡張をしている。

reparametrize tirck
def reparameterize(mu, sigma):
    eps = torch.randn_like(sigma)
    return mu + sigma * eps

変数変換トリックの部分。

VAE
class VAE(nn.Module):
    def __init__(self, latent_dim, times, batch_size):
        super(VAE, self).__init__()
        self.encoder = Encoder(latent_dim=latent_dim, times=times, batch_size=batch_size)
        self.decoder = Decoder(latent_dim=latent_dim, times=times, batch_size=batch_size)

    def get_loss(self, x):
        mu, sigma = self.encoder(x)
        z = reparameterize(mu, sigma)
        x_hat = self.decoder(z)

        batch_size = len(x)
        L1 = F.mse_loss(x_hat, x, reduction='sum')
        L2 = - torch.sum(1 + torch.log(sigma**2) - mu**2 - sigma**2)
        return (L1 + L2) / batch_size

4. 学習結果

学習時間はColabの無料で使えるGPUでおよそ13分かかりました。

損失曲線の確認

損失曲線は学習途中までは順調に減少。学習が局所最適解に陥ったと思われた部分もあるが一応うまく行っているっぽい。

スクリーンショット 2025-02-26 20.19.52.png

出力結果

スクリーンショット 2025-02-26 20.20.54.png

いい結果ではないです。

パラメータが足りないのだろうか?

パラメータを増やしてみます。3倍に。学習時間はほとんど変わらず14分くらい。

Before Architecture:

スクリーンショット 2025-02-26 20.21.23.png

After Architecture

スクリーンショット 2025-02-26 20.22.01.png

CNNのパラメータ数が32->96とのが大きな違い。

損失曲線

学習は失敗。

スクリーンショット 2025-02-26 20.22.19.png

出力結果

結果は真っ黒

スクリーンショット 2025-02-26 20.23.04.png

パラメータを増やした結果、学習が安定しない。

考えられる原因

  • 正則化などを入れなかったので勾配が消失した?
  • 潜在変数の次元数が多く、複雑に絡まった情報の前半を平均、後半を分散の推定値と雑に利用しているせい?

結果はだめでした。

考察

CNNのアーキテクチャで学習がうまくかなかった原因を2つ考察します。

  • CNNで抽出したデータをそのまま平均や分散の推定値として用いていること(CNNは画像のエッジや辺の特徴量の抽出しているので、それが潜在分布の平均や分散を推定するのに役立たない)

以下でアーキテクチャに若干の変更を加え画像を生成しています。

5. 実験(追加トピック)

FC層を追加

アーキテクチャ

Encoderの最終部分の平均と分散を抽出する部分に線形変換を追加する。
コードでいうと以下:

Encoder
class Encoder(nn.Module):
    def __init__(self, latent_dim, times, batch_size):
        super(Encoder, self).__init__()
        self.kernel_num = 32 * times
        self.conv1 = nn.Conv2d(1, self.kernel_num, kernel_size=3, stride=1, padding=1)  # Feature extraction
        self.conv2 = nn.Conv2d(self.kernel_num, self.kernel_num, kernel_size=3, stride=1, padding=1)  # Maintain 32 channels
        self.conv3 = nn.Conv2d(self.kernel_num, self.kernel_num, kernel_size=3, stride=1, padding=1)  # Maintain 32 channels
        self.downsample1 = nn.Conv2d(self.kernel_num, self.kernel_num, kernel_size=3, stride=2, padding=1)  # Downsampling
        self.downsample2 = nn.Conv2d(self.kernel_num, self.kernel_num, kernel_size=3, stride=2, padding=1)  # Further Downsampling
        self.downsample3 = nn.Conv2d(self.kernel_num, self.kernel_num, kernel_size=3, stride=2, padding=1)  # Further Downsampling
        self.fc_for_mean = nn.Linear(self.kernel_num * 4 * 4, 4 * 4)
        self.fc_for_logvar = nn.Linear(self.kernel_num * 4 * 4, 4 * 4)
        self.activation = nn.SiLU()
        self.batch_size = batch_size

    def forward(self, x):
        x = self.activation(self.conv1(x))
        x = self.activation(self.downsample1(x))
        x = self.activation(self.conv2(x))
        x = self.activation(self.downsample2(x))
        x = self.activation(self.conv3(x))
        x = self.activation(self.downsample3(x))
        x = torch.flatten(x, start_dim=1) # size 32(=barch size) * 4 * 4(=latent_dim)
        # Treat the first half as the mean and the second half as the sigma(variance)
        mu = self.fc_for_mean(x)
        logvar = self.fc_for_logvar(x)
        sigma = torch.exp(0.5 * logvar)
        return mu, sigma

損失曲線

滑らかな損失曲線が確認できた
スクリーンショット 2025-02-26 20.27.03.png

結果

CNNのみで構築したアーキテクチャよりもいい感じの結果が出ました。

スクリーンショット 2025-02-26 20.28.01.png

アーキテクチャのパラメータ増加と画質の変化

AIに期待したいスケーリング能力を調べたい。ということでモデルのCNN+FCのモデルパラメーターを3倍にする。
すると学習がうまくいかなかった。

スクリーンショット 2025-02-26 20.32.30.png

Nomarilze layerを追加すると

パラメータの数3倍、かつ正則化のためのアーキテクチャを追加しました。学習時間が20分と6分伸びましたが結果は以下です。勾配消失のような問題も起こらず、学習曲線も安定していました。

スクリーンショット 2025-02-26 21.02.02.png

カラフルな画像の学習方法

今回はMNISTの手書きの数字の画像を学習させました。
これは白黒なので学習画像のサイズは$(Batch-size, channels, width, height) = (任意, 1, 28, 28)$となっています。
カラフルな画像の場合、(R,B,G)のパラメータが追加されるので、$Channels=3$と変更すれば学習できます。

もし、さらに画像生成の質を向上させるならば

以下の機能を追加すると思います。

  • Self-Attention
    => 理由:画像の各要素同士の関係を学習できるため
  • Residiaul-connection
    =>画像生成系のタスクでよく使われているため

いわゆる、U-netと呼ばれる構造になります。

6. 参考文献

今回利用した活性化関数についてSiLU(ReLUに似ており、マイナスの勾配も計算できることから採用) :

CNNによる次元拡大:

・ゼロから作るDeep Learning ❺:

VAEをゆるっと解説:

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?