0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【Pytorch】VAEの実装

Posted at

1. VAEの概要

1.1 VAEとは

2014年に以下の論文で発表された「画像を生成する生成モデル」

Auto-Encoding Variational Bayes

元論文

2. PytorchでVAEを実装する

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F

# Pytorch画像用
import torchvision
import torchvision.transforms as transforms

# 画像表示用
import matplotlib.pyplot as plt

batch_size = 128

# データセットの取得
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
# データローダーの作成
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0
)

# エンコーダ
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        self.fc = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_var = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        #ニューラルネットワークで事後分布の平均・分散を計算する
        h = torch.relu(self.fc(x))
        mu = self.fc_mu(h) # μ
        log_var = self.fc_var(h) # log σ^2

        # 潜在変数を求める
        ## 標準正規分布を振る
        eps = torch.randn_like(torch.exp(log_var))
        ## 潜在変数の計算 μ + σ・ε
        z = mu + torch.exp(log_var / 2) * eps
        return mu, log_var, z

# デコーダ
class Decoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim, hidden_dim)
        self.fc_output = nn.Linear(hidden_dim, input_dim)

    def forward(self, z):
        h = torch.relu(self.fc(z))
        output = torch.sigmoid(self.fc_output(h))
        return output

# VAE全体
class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(input_dim, hidden_dim, latent_dim)
    
    def forward(self, x):
        mu, log_var, z = self.encoder(x) # エンコーダ
        x_decoded = self.decoder(z) # デコード
        return x_decoded, mu, log_var, z

# 学習コード
def loss_function(label, predict, mu, log_var):
    reconstruction_loss = F.binary_cross_entropy(predict, label, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    vae_loss = reconstruction_loss + kl_loss
    return vae_loss, reconstruction_loss, kl_loss

# ハイパーパラメータを設定
image_size = 28 * 28
h_dim = 32
z_dim = 16
num_epochs = 10
learning_rate = 1e-3

# デバイスの設定
## GPUが使える場合はGPUを使う
## torch.cuda.is_available() は、CUDA(GPU)が利用可能かどうかを確認する
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
##.to(device) を使って、モデル VAE を指定したデバイス(CPU または GPU)に転送する
model = VAE(image_size, h_dim, z_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# 学習
## 予測→損失計算→パラメータ更新→損失の表示を繰り返す
losses = []
model.train()
for epoch in range(num_epochs):
    train_loss = 0
    for i , (x, labels) in enumerate(train_loader):
        # 予測
        x = x.to(device).view(-1, image_size).to(torch.float32)
        x_recon, mu, log_var, z = model(x)
        # 損失計算の計算
        loss, recon_loss, kl_loss = loss_function(x, x_recon, mu, log_var)

        # パラメータの更新
        optimizer.zero_grad()
        # 誤差逆伝播
        loss.backward()
        optimizer.step()

        # 損失の表示
        if (i+1) % 10 == 0:
            print(f'Epoch: {epoch+1}, loss: {loss: 0.4f}, reconstruct loss: {recon_loss: 0.4f}, kl loss: {kl_loss: 0.4f}')
        losses.append(loss)

# 画像の生成
model.eval()

with torch.no_grad():
    z = torch.randn(25, z_dim).to(device)
    out = model.decoder(z)
out = out.view(-1, 28, 28)
out = out.cpu().detach().numpy()

# 画像の表示
fig, ax = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))
## グレースケールで表示
plt.gray()
for i in range(25):
    idx = divmod(i, 5)
    ax[idx].imshow(out[i])
    ax[idx].axis('off');
## fig.savefig() を呼び出す際には、保存するファイル名を必ず引数として指定
fig.savefig("01.png")

3. 生成された画像

01.png

4. まとめ

今回はVAEをPyTorchで実装した。

では!

参考にしたサイト

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?