自分の練習のため、ディープラーニングの一種であるCVAEを実装して学習させてみました。
本記事はメモ書きレベルの記述であり、VAEに関する知識があることを前提として書かれています。ご了承ください。
環境
- OS: Windows10
- Python: 3.7.5
- CUDA: 9.2
- numpy: 1.18.1
- torch: 1.4.0+cu92
- torchvision: 0.5.0+cu92
- matplotlib: 3.1.3
また、Jupyter Notebookを用いて実装しています
#参考記事
実装にあたって、参考にしたページを紹介します。
その他、Pytorchのexample実装も参考にしています。
#CVAEとは
**CVAE(Conditional Variational AutoEncoder)**はVAEの発展手法です。
通常のVAEでは、Encoderにデータを、Decoderに潜在変数を入力しますが、CVAEではこれらにデータの状態(Condition)を付加させます。これにより、次のメリットを得ることができます。
- Encoderで次元削除するとき、データのラベル以外の特徴を反映させることができる
- Decoderでデータ生成するとき、欲しいデータの状態を指定することができる
#実装と学習
今回、PytorchでCVAEを実装し、MNIST(手書き文字のデータセット)を学習させます。
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
%matplotlib inline
DEVICE = 'cuda'
SEED = 0
CLASS_SIZE = 10
BATCH_SIZE = 256
ZDIM = 16
NUM_EPOCHS = 50
# Set seeds
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
class CVAE(nn.Module):
def __init__(self, zdim):
super().__init__()
self._zdim = zdim
self._in_units = 28 * 28
hidden_units = 512
self._encoder = nn.Sequential(
nn.Linear(self._in_units + CLASS_SIZE, hidden_units),
nn.ReLU(inplace=True),
nn.Linear(hidden_units, hidden_units),
nn.ReLU(inplace=True),
)
self._to_mean = nn.Linear(hidden_units, zdim)
self._to_lnvar = nn.Linear(hidden_units, zdim)
self._decoder = nn.Sequential(
nn.Linear(zdim + CLASS_SIZE, hidden_units),
nn.ReLU(inplace=True),
nn.Linear(hidden_units, hidden_units),
nn.ReLU(inplace=True),
nn.Linear(hidden_units, self._in_units),
nn.Sigmoid()
)
def encode(self, x, labels):
in_ = torch.empty((x.shape[0], self._in_units + CLASS_SIZE), device=DEVICE)
in_[:, :self._in_units] = x
in_[:, self._in_units:] = labels
h = self._encoder(in_)
mean = self._to_mean(h)
lnvar = self._to_lnvar(h)
return mean, lnvar
def decode(self, z, labels):
in_ = torch.empty((z.shape[0], self._zdim + CLASS_SIZE), device=DEVICE)
in_[:, :self._zdim] = z
in_[:, self._zdim:] = labels
return self._decoder(in_)
def to_onehot(label):
return torch.eye(CLASS_SIZE, device=DEVICE, dtype=torch.float32)[label]
# Train
train_dataset = torchvision.datasets.MNIST(
root='./data',
train=True,
transform=transforms.ToTensor(),
download=True,
)
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=0
)
model = CVAE(ZDIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
model.train()
for e in range(NUM_EPOCHS):
train_loss = 0
for i, (images, labels) in enumerate(train_loader):
labels = to_onehot(labels)
# Reconstruction images
# Encode images
x = images.view(-1, 28*28*1).to(DEVICE)
mean, lnvar = model.encode(x, labels)
std = lnvar.exp().sqrt()
epsilon = torch.randn(ZDIM, device=DEVICE)
# Decode latent variables
z = mean + std * epsilon
y = model.decode(z, labels)
# Compute loss
kld = 0.5 * (1 + lnvar - mean.pow(2) - lnvar.exp()).sum(axis=1)
bce = F.binary_cross_entropy(y, x, reduction='none').sum(axis=1)
loss = (-1 * kld + bce).mean()
# Update model
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += loss.item() * x.shape[0]
print(f'epoch: {e + 1} epoch_loss: {train_loss/len(train_dataset)}')
結果
epoch: 1 epoch_loss: 200.2185905436198
epoch: 2 epoch_loss: 160.22688263346353
epoch: 3 epoch_loss: 148.69330817057292
# 中略
epoch: 48 epoch_loss: 98.95304524739583
epoch: 49 epoch_loss: 98.6720672281901
epoch: 50 epoch_loss: 98.65486107177735
以下、実装と学習の要点を列挙します
- 学習に
torchvision.datasets.MNIST
の訓練用データ6000枚を利用し、エポック数を50とする - EncoderとDecoderを持つCVAEクラスを設計し、
forward
は実装せずencode
及びdecode
メソッドを実装する - データセットのラベル(書かれている数字)をone-hotなベクトルに変換し、EncoderとDecoderの入力に付加する
- 学習時のミニバッチサイズは大きめの256とする1
- Encoder、Decoderともに単純なMLPで構成する
- 潜在変数の次元を16とする
#CVAEによる画像生成
VAEは次元削除とデータ生成の2つの応用が可能ですが、今回はデータ生成に焦点をあてます。
先程学習させたCVAEのDecoderを用いて、手書き画像を新しく生成することを考えます。
###「5」の画像の生成
Decoderに与えるラベルの情報を「5」に固定し、標準正規分布に従う乱数を100個生成して、それぞれに対応する画像を生成します。
# Generation data with label '5'
NUM_GENERATION = 100
os.makedirs(f'img/cvae/generation/label5/', exist_ok=True)
model.eval()
for i in range(NUM_GENERATION):
z = torch.randn(ZDIM, device=DEVICE).unsqueeze(dim=0)
label = torch.tensor([5], device=DEVICE)
with torch.no_grad():
y = model.decode(z, to_onehot(label))
y = y.reshape(28, 28).cpu().detach().numpy()
# Save image
fig, ax = plt.subplots()
ax.imshow(y)
ax.set_title(f'Generation(label={label.cpu().detach().numpy()[0]})')
ax.tick_params(
labelbottom=False,
labelleft=False,
bottom=False,
left=False,
)
plt.savefig(f'img/cvae/generation/label5/img{i + 1}')
plt.close(fig)
結果
形が崩れているものもありますが、様々な「5」の画像を生成できています。
###太い数字画像の生成
torchvision.datasets.MNIST
のテスト用画像から、太く書かれた数字を探しました。
次の画像は、データセットの49番目の画像です。
非常に太く「4」と書かれていますね。
このデータに対応する潜在変数をEncoderで求ます。
test_dataset = torchvision.datasets.MNIST(
root='./data',
train=False,
transform=transforms.ToTensor(),
download=True,
)
target_image, label = list(test_dataset)[48]
x = target_image.view(1, 28*28).to(DEVICE)
with torch.no_grad():
mean, _ = model.encode(x, to_onehot(label))
z = mean
print(f'z = {z.cpu().detach().numpy().squeeze()}')
結果
z = [ 0.7933388 2.4768877 0.49229255 -0.09540698 -1.7999544 0.03376897
0.01600834 1.3863252 0.14656337 -0.14543885 0.04157912 0.13938689
-0.2016176 0.5204378 -0.08096244 1.0930295 ]
この16次元のベクトルは、学習時に与えているラベル以外の画像の情報を持っています。つまり、「4の形である」という情報ではなく、「非常に太い」という情報を持っているはずです。
そこで、この潜在変数を用いて、Decoderに与えるラベルの情報を変えながら画像を生成してみます。
os.makedirs(f'img/cvae/generation/fat', exist_ok=True)
for label in range(CLASS_SIZE):
with torch.no_grad():
y = model.decode(z, to_onehot(label))
y = y.reshape(28, 28).cpu().detach().numpy()
fig, ax = plt.subplots()
ax.imshow(y)
ax.set_title(f'Generation(label={label})')
ax.tick_params(
labelbottom=False,
labelleft=False,
bottom=False,
left=False,
)
plt.savefig(f'img/cvae/generation/fat/img{label}')
plt.close(fig)
結果
「2」がちょっと怪しいですが、太い数字の画像を生成できています。
#おわりに
CVAEについて、随分前から知識だけは知っていたのですが、実装したのは今回が初めてでした。うまくいったので良かったです。知識だけでなく実装してみるのは大事ですね。
生成した画像の中にはきれいな数字にならなかったものもありましたが、VAEのネットワークに畳み込みや転置畳み込みを用いることで解消されるかもしれません。
今回は割愛しましたが、VAE系はどの特徴が低次元空間上のどこにマップされているのか分析することが重要だと認識しています。今度はその分析もやってみたいです。
-
ミニバッチ内に全てのラベルを持つデータが存在するようにして、Encoderによるミニバッチの像が潜在変数空間上で標準正規分布に従うようにするためです。 ↩