はじめに
PyTorchのVAEモデルを作成しました。ネットで検索した際に、MNISTを用いたものは多くありましたが、その他のデータセットを用いたものは少なかったので記事にしてみました。
理論的な側面よりも実装することを重視して記事にしたので間違いや不正確な部分がある点はご了承ください。また、コメントで指摘していただけると大変ありがたいです。
実装内容
PyTorchのVAEモデルを使って数字とひらがなの画像データを入力し、出力として画像を生成することを目指す。
ソースコード
動いているところがすぐに見たい方はこちらをどうぞ
Google Colab
データの作成
ここは後の処理がしやすいように、MNISTの画像をPNG形式に変換しました。これを数字データとして扱います。ひらがなデータはここのサイトからお借りしたデータを用いました。
# MNISTデータをPNGで保存
import os
from torchvision import datasets
# 保存先フォルダ設定
rootdir = "MNIST"
traindir = rootdir + "/train"
testdir = rootdir + "/test"
# MNIST データセット読み込み
train_dataset = datasets.MNIST(root=rootdir, train=True, download=True)
test_dataset = datasets.MNIST(root=rootdir, train=False, download=True)
# 画像保存 train
number = 0
for img, label in train_dataset:
savedir = traindir + "/" + str(label)
os.makedirs(savedir, exist_ok=True)
savepath = savedir + "/" + str(number).zfill(5) + ".png"
img.save(savepath)
number = number + 1
print(savepath)
# 画像保存 test
number = 0
for img, label in test_dataset:
savedir = testdir + "/" + str(label)
os.makedirs(savedir, exist_ok=True)
savepath = savedir + "/" + str(number).zfill(5) + ".png"
img.save(savepath)
number = number + 1
print(savepath)
wget http://lab.ndl.go.jp/dataset/hiragana73.zip
unzip ./hiragana73.zip
ライブラリのインポート
後々使うライブラリをインポートしておきます。どこで何が使われているかは、ソースコードを参照してください。
from torchvision import datasets, transforms
import torch
import os
from PIL import Image,ImageOps
import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch import optim
データセット、データローダーの実装
データセットとデータローダーの実装です。ここがMNISTデータセットのみを使用する場合と大きく異なります。MNISTデータセットを使用する場合は、torchvisionを用いれば簡単に行えますが、今回はひらがなも含んでいるのでデータセットを自作します。
DataSetクラスを作成し、初期化時はファイルのリンクをself.data配列に渡しておきます。初期化時に画像データを配列に渡してしまうと、莫大なメモリが必要になりそうです。(実験してませんが・・)
データが必要になったときにself.dataにあるリンクからPILを用いて読み込みリサイズしてからテンソル化しておきます。
getitem関数のif節はMNISTが背景黒で白字に対して、ひらがなが背景白で黒字であるため合わせるために、MNISTの画像を白黒反転させてあります。
あと、ひらがなデータセットに対してMNISTデータセットの数が多いのでMNISTのデータを使う枚数を制限してあります。(数字の過学習を防ぐため)
folder_list_hiragana = (os.listdir(path='.\\hiragana73'))
folder_list_MNIST = (os.listdir(path='.\\MNIST\\train'))
class DataSet:
def __init__(self):
self.data = []
for folder in folder_list_hiragana:
file_list = (os.listdir(path='.\\hiragana73\\'+folder))
for file in file_list:
self.data.append('.\\hiragana73\\'+folder+'\\'+file)
for folder in folder_list_MNIST:
file_list = (os.listdir(path='.\\MNIST\\train\\'+folder))
for i,file in enumerate(file_list):
self.data.append('.\\MNIST\\train\\'+folder+'\\'+file)
if i > 1000:
break
def __len__(self):
return len(self.data)
def __getitem__(self,index):
file_path = self.data[index]
image = Image.open(file_path)
if "MNIST" in file_path:
image = ImageOps.invert(image)
image = image.resize((28,28))
image = image.convert('L')
image = np.array(image)
image = torchvision.transforms.functional.to_tensor(image)
return image
dataset_train = DataSet()
dataset_valid = DataSet()
dataloader_train = torch.utils.data.DataLoader(dataset_train,batch_size=1000,shuffle=True,num_workers=0)
dataloader_valid = torch.utils.data.DataLoader(dataset_valid,batch_size=1000,shuffle=True,num_workers=0)
デバイス定義
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('device:', device)
ネットワーク定義
ネットワーク定義はこちらのサイトを参考に作成しました。
VAEなので潜在変数zをサンプリングする層(エンコーダー)と、潜在変数zから復元する層(デコーダー)に分かれています。
損失関数のデルタはlog関数を適用するときに、値が0に近いと損失関数の値が無限に発散してしまって学習できなくなるので微小量加算することで損失関数が発散することを防いでおきます。
class VAE(nn.Module):
def __init__(self, z_dim):
super(VAE, self).__init__()
self.dense_enc1 = nn.Linear((28*28), 400)
self.dense_enc2 = nn.Linear(400, 400)
self.dense_encmean = nn.Linear(400, z_dim)
self.dense_encvar = nn.Linear(400, z_dim)
self.dense_dec1 = nn.Linear(z_dim, 400)
self.dense_dec2 = nn.Linear(400, 400)
self.dense_dec3 = nn.Linear(400, 28*28)
def _encoder(self, x):
x = F.relu(self.dense_enc1(x))
x = F.relu(self.dense_enc2(x))
mean = self.dense_encmean(x)
var = F.softplus(self.dense_encvar(x))
return mean, var
def _sample_z(self, mean, var):
epsilon = torch.randn(mean.shape).to(device)
return mean + torch.sqrt(var) * epsilon
def _decoder(self, z):
x = F.relu(self.dense_dec1(z))
x = F.relu(self.dense_dec2(x))
x = torch.sigmoid(self.dense_dec3(x))
return x
def forward(self, x):
mean, var = self._encoder(x)
z = self._sample_z(mean, var)
x = self._decoder(z)
return x, z
def loss(self, x):
delta = 1e-7
mean, var = self._encoder(x)
KL = -0.5 * torch.mean(torch.sum(1 + torch.log(var + delta) - mean**2 - var))
z = self._sample_z(mean, var)
y = self._decoder(z)
reconstruction = torch.mean(torch.sum(x * torch.log(y + delta) + (1 - x) * torch.log(1 - y + delta)))
lower_bound = [-KL, reconstruction]
return -sum(lower_bound)
メイン関数
一通り前処理は終わったので学習させるためのメインループを作成します。
二次元画像を一次元に変換してモデルに入力することだけ注意すればよさそうです。
model = VAE(50).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()
for i in range(50):
losses = []
for x in dataloader_train:
x = x.to(device)
x = x.view(-1,28*28)
model.zero_grad()
y = model(x)
loss = model.loss(x)
loss.backward()
optimizer.step()
losses.append(loss.cpu().detach().numpy())
print("EPOCH: {} loss: {}".format(i, np.average(losses)))
torch.save(model.state_dict(),"model.pth")
結果を出力してみる
import matplotlib.pyplot as plt
import numpy as np
fig = plt.figure(figsize=(10, 3))
# model_path = 'model.pth'
# model.load_state_dict(torch.load(model_path))
model.eval()
zs = []
for x in dataloader_valid:
for i, im in enumerate(x.view(-1, 28, 28).detach().numpy()[:10]):
ax = fig.add_subplot(3, 10, i+1, xticks=[], yticks=[])
ax.imshow(im, 'gray')
x = x.to(device)
x = x.view(-1,28*28)
y, z = model(x)
zs.append(z)
y = y.view(-1, 28, 28)
for i, im in enumerate(y.cpu().detach().numpy()[:10]):
ax = fig.add_subplot(3, 10, i+11, xticks=[], yticks=[])
ax.imshow(im, 'gray')
結果は以下のようになりました。上が入力画像で下が出力画像です。
入力画像を次元削減して圧縮した後、復元してもそれなりに読める字になっていると思います。
参考にさせていただいたサイト
PyTorchでVAEのモデルを実装してMNISTの画像を生成する
https://www.sambaiz.net/article/212/
文字画像データセット(平仮名73文字版)
https://github.com/ndl-lab/hiragana_mojigazo/blob/master/readme.md
Variational Autoencoder徹底解説
https://qiita.com/kenmatsu4/items/b029d697e9995d93aa24
Pytorch - 自作のデータセットを扱う Dataset クラスを作る方法 - pystyle
https://pystyle.info/pytorch-how-to-create-custom-dataset-class/