1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

VAE(Variational Autoencoder)をPyTorchで動かしてみた

1
Last updated at Posted at 2026-05-10

はじめに

World Modelsを理解してみたいと思い、勉強を進めてきたのですが、その中でVAE(Variational Autoencoder)の知識が必要になったので、アウトプットの意味も込め、自分なりにわかりやすく解説してみようと思います。
ただ理論から一つ一つ説明するのはかなり大変なので、この記事では、
"VAEを軽く説明したあと、実際に動かして実験してみよう" と思います。
具体的には、以下の3つの実験を行います。

  • 実験1: VAEはちゃんと画像を復元できるの?
  • 実験2: 潜在変数zって数字ごとに分布に法則性があるの?
  • 実験3: 可視化された潜在空間の通りに数字は生成されるのか?

なお、実装コードはこちら(https://github.com/origamider/machine-learning)

VAEを大雑把に説明

VAEは入力データxから潜在変数zを学ぶモデルです。
潜在変数zは重要な特徴が詰まったやつね。
VAEでは、あらかじめ、「潜在変数zは平均0,分散$I$の正規分布に従っている」と仮定しています。
なお、似たようなものとして、AE(Autoencoder)というものがあります。
VAEはEncoderでμ,σを求めてから潜在変数zを正規分布で求めますが、
AEはEncoderから直接潜在変数zを求める、という風に理解しています。

スクリーンショット 2026-05-10 15.37.28.png

スクリーンショット 2026-05-10 15.41.52.png

実装準備

ここからはコードを動かしていきます。まずは準備から。

モジュールimport

import torch
import torch.nn as nn
import numpy as np
import torchvision
from torchvision import datasets, transforms
import torch.optim as optim
import matplotlib.pyplot as plt
import japanize_matplotlib

必要なモジュールをimportする。

VAE定義

class Encoder(nn.Module):
    def __init__(self,input_dim,hidden_dim,latent_dim):
        super().__init__()
        self.layer = nn.Linear(input_dim,hidden_dim)
        self.relu = nn.ReLU()
        self.softplus = nn.Softplus()
        self.layer_mu = nn.Linear(hidden_dim,latent_dim)
        self.layer_sigma = nn.Linear(hidden_dim,latent_dim)
    
    def forward(self, x):
        base = self.relu(self.layer(x))
        mu = self.layer_mu(base)
        sigma = self.softplus(self.layer_sigma(base))
        return mu, sigma

class Decoder(nn.Module):
    def __init__(self,latent_dim,hidden_dim,output_dim):
        super().__init__()
        self.l1 = nn.Linear(latent_dim,hidden_dim)
        self.l2 = nn.Linear(hidden_dim,output_dim)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
    
    def forward(self,z):
        tmp = self.relu(self.l1(z))
        x_hat = self.sigmoid(self.l2(tmp))
        return x_hat

def create_z(mu, sigma):
    epsilon = torch.randn_like(sigma)
    z = mu + epsilon * sigma
    return z
    
class VAE(nn.Module):
    def __init__(self,input_dim,hidden_dim,latent_dim):
        super().__init__()
        self.encoder = Encoder(input_dim,hidden_dim,latent_dim)
        self.decoder = Decoder(latent_dim,hidden_dim,input_dim)
        self.mseloss = nn.MSELoss(reduction="sum")
    
    def get_loss(self, x):
        mu, sigma = self.encoder(x)
        z = create_z(mu, sigma)
        x_hat = self.decoder(z)
        return (self.mseloss(x_hat,x) * 0.5 - torch.sum(1 + torch.log(sigma ** 2) - mu ** 2 - sigma ** 2) * 0.5) / len(x)

ハイパーパラメータの用意

#ハイパーパラメータ
input_dim = 784 #いじっちゃダメ。MNIST画像は28*28個の数字(0以上1以下?)でできている。
hidden_dim = 32 #いじってOK。エンコーダ、デコーダの隠れ層の次元。
latent_dim = 2  #いじってOK。潜在変数zの次元。
batch_size = 64 #いじってOK。機械学習時にbatch_size分取り出してミニバッチ学習。
num_epochs = 30 #いじってOK。機械学習回数。

input_dim以外は基本的に数値をいじってOK。

データセットの用意(MNIST)

#前処理。MNIST画像をTensor型に直し、形状(28,28)を(784)のように、1列にする。
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(torch.flatten)
])
# MNISTデータの使用(機械学習用)
dataset = datasets.MNIST(
    root = "./data/",
    train = True,
    download= True,
    transform=transform
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True
)

今回はMNISTデータを使用する。

学習

model = VAE(input_dim,hidden_dim,latent_dim)
optimizer = optim.Adam(model.parameters()) #最適化関数はAdamを使用。他にもSGDとかでもいけると思う。


# 学習
for epoch in range(num_epochs):
    loss_sum = 0
    ct = 0
    for x,label in dataloader:
        optimizer.zero_grad()
        loss = model.get_loss(x)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
        ct += 1
    # lossの可視化
    loss_avg = loss_sum / ct
    print(loss_avg)

lossを出力する。学習するごとに下がっていくはず。

実験1:VAEはちゃんと画像を復元できるの?

# MNISTデータの使用(実験用)
test_dataset = datasets.MNIST(
    root = "./data/",
    train = False,
    download= True,
    transform=transform
)
# 実験1:MNIST画像と復元後の画像を比較してみよう。上がオリジナル、下が復元後の画像。
n = 8
x = torch.stack([test_dataset[i][0] for i in range(n)])
label = np.array([test_dataset[i][1] for i in range(n)])
mu, sigma = model.encoder(x)
z = create_z(mu, sigma)
generate_x = model.decoder(z)
fig, axes = plt.subplots(2,n,figsize=(8,4))
for i in range(n):
    axes[0][i].imshow(x[i].view(28,28).detach().numpy(),cmap="gray")
    axes[1][i].imshow(generate_x[i].view(28,28).detach().numpy(),cmap="gray")
    axes[0][i].axis("off")
    axes[1][i].axis("off")
plt.show()

結果はこちら。

スクリーンショット 2026-05-10 16.19.16.png

いい感じに復元できている。
まあよく見ると、2が3に見えたり、4が9に見えたり、みたいなびみょいやつあるけど。

実験2:潜在変数zって数字ごとに分布に法則性があるの?

まあ、これはいろんな記事とかで検証されていると思いますが、やってみます。
今回は2次元の潜在変数zを数字ごとにプロットします。

xs = torch.stack([test_dataset[i][0] for i in range(len(test_dataset))])
labels = np.array([test_dataset[i][1] for i in range(len(test_dataset))])

mu, sigma = model.encoder(xs)
zs = create_z(mu, sigma)
zs = zs.detach().numpy()
# 0~9までの数字に該当する潜在変数zを出力する
# 注意:latent_dim=2にしてね。
for num in range(10):
    index = (labels == num)
    plt.scatter(zs[index,0],zs[index,1],label=str(num))
plt.legend()
plt.xlabel("z1")
plt.ylabel("z2")
plt.title("潜在変数zを2次元上で表示")
plt.show()

結果はこちらです。

スクリーンショット 2026-05-10 16.26.36.png

見事に数字ごとに綺麗に分布していていますね。
このことから、たとえば、
青色付近の領域から適当に座標を選んで、それを潜在変数zとして、
decoderで復元画像x^に戻したら、なんと"0に近い画像"が生成されるということになります!
面白いな🧐

実験3:可視化された潜在空間の通りに数字は生成されるのか?

実験2の結果から、潜在変数zの分布には数字ごとの法則性が見えましたね。
そこで考えたのですが、法則性の通りにzを選んだら、本当にその位置に対応する数字が復元されるのかが気になったので、検証します。

2次元の潜在変数は-3<=$z_{1}$<=3,-3<=$z_{2}$<=3周辺に分布していることがわかったので、
この範囲を20×20の格子に区切り、各格子点をzとしてdecoderに渡して画像を生成しました。

# 実験3:潜在変数から復元画像xを求める
xs = torch.linspace(-3,3,20)
ys = torch.linspace(3,-3,20)
grid_y, grid_x = torch.meshgrid(ys,xs)
grid_x = torch.flatten(grid_x)
grid_y = torch.flatten(grid_y)
zs = torch.stack((grid_x,grid_y),dim=1)
x_hat = model.decoder(zs)
x_hat = x_hat.view(20,20,28,28).permute(0,2,1,3).reshape(560,560).detach().numpy()
plt.imshow(x_hat,cmap="gray")
plt.show()

結果はこちら。

スクリーンショット 2026-05-10 16.33.12.png

いい感じに数字ごとに配置されていますね。
実際に下の画像を見ると分かる通り、実験2の散布図と、実験3のグリッド図で、数字の位置が対応していることがわかります。

スクリーンショット 2026-05-10 22.21.04.png

最後にちょっぴり

最初はVAEの理論まで説明しようと思いましたが、あまりにも書くのが大変だったので、
実験メインになりました😭
ただ、すでにVAEの記事は豊富にあるので、ぜひ参考にしてみるといいでしょう。
自分もまだ勉強中の身なので、もし記事の中で間違いや誤解を招く表現があれば、コメントで教えていただけると嬉しいです😁

分かりやすい参考になる記事

  • 斎藤康毅『ゼロから作るDeep Learning ❺』オライリー・ジャパン, 2024(本記事の実装コードもこの書籍を参考にしました)

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?