1. 要約
本記事ではVAEをCNNで実装しました。
また、以下の点にも焦点を当てています
- VAEがAE(Autoencoder)のアーキテクチャを採用する理由
- CNNを用いるメリット
VAEの基礎は過去の記事で解説しています:
コード:
2. CNNを用いたVAEの構築
線形結合層(Full connection layer)を用いたVAEの特徴
まず、データをFCに入力可能な状態にするために、データをフラットな状態に変換します。
その後、NNを用いて潜在空間に任意の次元数まで圧縮します。
そこで得られた潜在変数を用いて、再度情報を元に戻し、再構成誤差の低下と潜在空間の確率分布が設定した分布と近くなるように学習します。
CNNをVAEに用いるメリット
CNNは空間の特徴量を取得できると考えられています。
学習させたいものが画像である場合、空間情報は大切であると考えられるので、VAEの中にCNNを組み込む動機が与えられます。
加えて、学習するパラメータの数も少なくすることが可能です(=計算コストの削減)。
CNNを用いたVAEの構造
Encoder(初期案)
Decoder
注意点
- Encoder:潜在確率分布の平均や分散を求める時にもCNNを利用している
3. 実装
コード全体はこちら:
class Encoder(nn.Module):
def __init__(self, latent_dim, times, batch_size):
super(Encoder, self).__init__()
self.kernel_num = 32 * times
self.conv1 = nn.Conv2d(1, self.kernel_num, kernel_size=3, stride=1, padding=1) # Feature extraction
self.conv2 = nn.Conv2d(self.kernel_num, self.kernel_num, kernel_size=3, stride=1, padding=1) # Maintain 32 channels
self.conv3 = nn.Conv2d(self.kernel_num, self.kernel_num, kernel_size=3, stride=1, padding=1) # Maintain 32 channels
self.downsample1 = nn.Conv2d(self.kernel_num, self.kernel_num, kernel_size=3, stride=2, padding=1) # Downsampling
self.downsample2 = nn.Conv2d(self.kernel_num, self.kernel_num, kernel_size=3, stride=2, padding=1) # Further Downsampling
self.downsample3 = nn.Conv2d(self.kernel_num, self.kernel_num, kernel_size=3, stride=2, padding=1) # Further Downsampling
self.conv_for_mu = nn.Conv2d(self.kernel_num, self.kernel_num//2, kernel_size=3, stride=1, padding=1)
self.conv_for_logvar = nn.Conv2d(self.kernel_num, self.kernel_num//2, kernel_size=3, stride=1, padding=1)
self.activation = nn.SiLU()
self.batch_size = batch_size
def forward(self, x):
x = self.activation(self.conv1(x))
x = self.activation(self.downsample1(x))
x = self.activation(self.conv2(x))
x = self.activation(self.downsample2(x))
x = self.activation(self.conv3(x))
x = self.activation(self.downsample3(x))
mu = self.activation(self.conv_for_mu(x))
logvar = self.activation(self.conv_for_logvar(x))
mu = torch.flatten(mu, start_dim=1)
logvar = torch.flatten(logvar, start_dim=1)
sigma = torch.exp(0.5 * logvar)
return mu, sigma
ここではダウンサンプリングを3回行っている。
最初の画像のサイズが28 * 28 -> 14 * 14 -> 7 * 7 -> 4 * 4となっている。
4 * 4に画像を圧縮後、平均を抽出するようのCNNに通し(B=32, C=16, W=4, H=4)とし、これを平坦化(=flatten)することで平均の特徴量を抽出させようとしている。
class Decoder(nn.Module):
def __init__(self, latent_dim, times, batch_size):
super(Decoder, self).__init__()
self.kernel_num = 32 * times
self.conv1 = nn.Conv2d(self.kernel_num//2, self.kernel_num, kernel_size=3, stride=1, padding=1) # Feature extraction
self.conv2 = nn.Conv2d(self.kernel_num, self.kernel_num//2, kernel_size=3, stride=1, padding=1) # Maintain 32 channels
self.conv3 = nn.Conv2d(self.kernel_num//2, self.kernel_num//4, kernel_size=3, stride=1, padding=1) # Maintain 32 channels
self.upsample1 = nn.ConvTranspose2d(self.kernel_num, self.kernel_num, kernel_size=3, stride=2, padding=1, output_padding=0) # to adjust expected size of outputs
self.upsample2 = nn.ConvTranspose2d(self.kernel_num//2, self.kernel_num//2, kernel_size=3, stride=2, padding=1, output_padding=1)
self.upsample3 = nn.ConvTranspose2d(self.kernel_num//4, self.kernel_num//4, kernel_size=3, stride=2, padding=1, output_padding=1)
self.deconv_final = nn.ConvTranspose2d(self.kernel_num//4, 1, kernel_size=3, stride=1, padding=1)
self.activation = nn.SiLU()
self.batch_size = batch_size
def forward(self, z):
# print(z.shape) # (batch_size, 256)
x = z.view(self.batch_size, self.kernel_num//2, 4, 4)
x = self.activation(self.conv1(x))
x = self.activation(self.upsample1(x))
x = self.activation(self.conv2(x))
x = self.activation(self.upsample2(x))
x = self.activation(self.conv3(x))
x = self.activation(self.upsample3(x))
x = self.deconv_final(x)
x_hat = F.sigmoid(x)
return x_hat
アップサンプリングを3回組み合わせて、画像を28 * 28に次元拡張をしている。
def reparameterize(mu, sigma):
eps = torch.randn_like(sigma)
return mu + sigma * eps
変数変換トリックの部分。
class VAE(nn.Module):
def __init__(self, latent_dim, times, batch_size):
super(VAE, self).__init__()
self.encoder = Encoder(latent_dim=latent_dim, times=times, batch_size=batch_size)
self.decoder = Decoder(latent_dim=latent_dim, times=times, batch_size=batch_size)
def get_loss(self, x):
mu, sigma = self.encoder(x)
z = reparameterize(mu, sigma)
x_hat = self.decoder(z)
batch_size = len(x)
L1 = F.mse_loss(x_hat, x, reduction='sum')
L2 = - torch.sum(1 + torch.log(sigma**2) - mu**2 - sigma**2)
return (L1 + L2) / batch_size
4. 学習結果
学習時間はColabの無料で使えるGPUでおよそ13分かかりました。
損失曲線の確認
損失曲線は学習途中までは順調に減少。学習が局所最適解に陥ったと思われた部分もあるが一応うまく行っているっぽい。
出力結果
いい結果ではないです。
パラメータが足りないのだろうか?
パラメータを増やしてみます。3倍に。学習時間はほとんど変わらず14分くらい。
Before Architecture:
After Architecture
CNNのパラメータ数が32->96とのが大きな違い。
損失曲線
学習は失敗。
出力結果
結果は真っ黒
パラメータを増やした結果、学習が安定しない。
考えられる原因
- 正則化などを入れなかったので勾配が消失した?
- 潜在変数の次元数が多く、複雑に絡まった情報の前半を平均、後半を分散の推定値と雑に利用しているせい?
結果はだめでした。
考察
CNNのアーキテクチャで学習がうまくかなかった原因を2つ考察します。
- CNNで抽出したデータをそのまま平均や分散の推定値として用いていること(CNNは画像のエッジや辺の特徴量の抽出しているので、それが潜在分布の平均や分散を推定するのに役立たない)
以下でアーキテクチャに若干の変更を加え画像を生成しています。
5. 実験(追加トピック)
FC層を追加
アーキテクチャ
Encoderの最終部分の平均と分散を抽出する部分に線形変換を追加する。
コードでいうと以下:
class Encoder(nn.Module):
def __init__(self, latent_dim, times, batch_size):
super(Encoder, self).__init__()
self.kernel_num = 32 * times
self.conv1 = nn.Conv2d(1, self.kernel_num, kernel_size=3, stride=1, padding=1) # Feature extraction
self.conv2 = nn.Conv2d(self.kernel_num, self.kernel_num, kernel_size=3, stride=1, padding=1) # Maintain 32 channels
self.conv3 = nn.Conv2d(self.kernel_num, self.kernel_num, kernel_size=3, stride=1, padding=1) # Maintain 32 channels
self.downsample1 = nn.Conv2d(self.kernel_num, self.kernel_num, kernel_size=3, stride=2, padding=1) # Downsampling
self.downsample2 = nn.Conv2d(self.kernel_num, self.kernel_num, kernel_size=3, stride=2, padding=1) # Further Downsampling
self.downsample3 = nn.Conv2d(self.kernel_num, self.kernel_num, kernel_size=3, stride=2, padding=1) # Further Downsampling
self.fc_for_mean = nn.Linear(self.kernel_num * 4 * 4, 4 * 4)
self.fc_for_logvar = nn.Linear(self.kernel_num * 4 * 4, 4 * 4)
self.activation = nn.SiLU()
self.batch_size = batch_size
def forward(self, x):
x = self.activation(self.conv1(x))
x = self.activation(self.downsample1(x))
x = self.activation(self.conv2(x))
x = self.activation(self.downsample2(x))
x = self.activation(self.conv3(x))
x = self.activation(self.downsample3(x))
x = torch.flatten(x, start_dim=1) # size 32(=barch size) * 4 * 4(=latent_dim)
# Treat the first half as the mean and the second half as the sigma(variance)
mu = self.fc_for_mean(x)
logvar = self.fc_for_logvar(x)
sigma = torch.exp(0.5 * logvar)
return mu, sigma
損失曲線
結果
CNNのみで構築したアーキテクチャよりもいい感じの結果が出ました。
アーキテクチャのパラメータ増加と画質の変化
AIに期待したいスケーリング能力を調べたい。ということでモデルのCNN+FCのモデルパラメーターを3倍にする。
すると学習がうまくいかなかった。
Nomarilze layerを追加すると
パラメータの数3倍、かつ正則化のためのアーキテクチャを追加しました。学習時間が20分と6分伸びましたが結果は以下です。勾配消失のような問題も起こらず、学習曲線も安定していました。
カラフルな画像の学習方法
今回はMNISTの手書きの数字の画像を学習させました。
これは白黒なので学習画像のサイズは$(Batch-size, channels, width, height) = (任意, 1, 28, 28)$となっています。
カラフルな画像の場合、(R,B,G)のパラメータが追加されるので、$Channels=3$と変更すれば学習できます。
もし、さらに画像生成の質を向上させるならば
以下の機能を追加すると思います。
- Self-Attention
=> 理由:画像の各要素同士の関係を学習できるため - Residiaul-connection
=>画像生成系のタスクでよく使われているため
いわゆる、U-netと呼ばれる構造になります。
6. 参考文献
今回利用した活性化関数についてSiLU(ReLUに似ており、マイナスの勾配も計算できることから採用) :
CNNによる次元拡大:
・ゼロから作るDeep Learning ❺:
VAEをゆるっと解説: