A model combining a UNet and a VAE
What I want to solve
I replaced the encode/decode part with a UNet, changed the point where the encoder and decoder connect to the VAE approach (mapping the features into a normally distributed latent space), and trained the model.
However, the loss diverges, and making the model lighter did not improve it.
So I would be very grateful for feedback on whether there is anything wrong with the code itself, or suggestions along the lines of "this part would be better done that way".
Thank you in advance.
* There is no requirement that the encoder/decoder use the structure described here. However, my experience is still limited, so I am posting to ask people with more knowledge what kind of approach works well in a case like this.
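To make the "VAE approach" at the connection point concrete: the encoder predicts a mean and a variance for each latent dimension, and the latent vector is sampled with the reparameterization trick z = μ + √var·ε so that gradients can flow through the sampling step. A minimal sketch of just that connection (the function name is only for illustration; the full code below does the same thing inside the VAE class):

import torch

def reparameterize(mean: torch.Tensor, var: torch.Tensor) -> torch.Tensor:
    # z = mean + sqrt(var) * eps, with eps drawn from a standard normal
    eps = torch.randn_like(mean)  # created on the same device/dtype as mean
    return mean + torch.sqrt(var) * eps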
The problem / error that occurs
The relevant code is shown below.
import torch
import torch.nn as nn
import torch.nn.functional as F
# PyTorch画像用
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
class Conv3(nn.Module):
def __init__(
self, in_channels: int, out_channels: int, is_res: bool = False
) -> None:
super().__init__()
self.main = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 1, 1),
nn.GroupNorm(8, out_channels),
nn.ReLU(),
)
self.conv = nn.Sequential(
nn.Conv2d(out_channels, out_channels, 3, 1, 1),
nn.GroupNorm(8, out_channels),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, 3, 1, 1),
nn.GroupNorm(8, out_channels),
nn.ReLU(),
)
self.is_res = is_res
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.main(x)
if self.is_res:
x = x + self.conv(x)
return x / 1.414
else:
return self.conv(x)
class UnetDown(nn.Module):
def __init__(self, in_channels: int, out_channels: int) -> None:
super(UnetDown, self).__init__()
layers = [Conv3(in_channels, out_channels), nn.MaxPool2d(2)]
self.model = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.model(x)
class UnetUp(nn.Module):
def __init__(self, in_channels: int, out_channels: int) -> None:
super(UnetUp, self).__init__()
layers = [
nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
Conv3(out_channels, out_channels),
Conv3(out_channels, out_channels),
]
self.model = nn.Sequential(*layers)
def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
x = torch.cat((x, skip), 1)
x = self.model(x)
return x
class Encoder(nn.Module):
def __init__(self, in_channels, out_channels, latent_dim, n_feat):
super(Encoder, self).__init__()
# Encoder layers
self.in_channels = in_channels
self.out_channels = out_channels
self.n_feat = n_feat
self.init_conv = Conv3(in_channels, n_feat, is_res=True)
self.down1 = UnetDown(n_feat, 2 * n_feat)
# self.down2 = UnetDown(n_feat, 2 * n_feat)
# self.down3 = UnetDown(2 * n_feat, 2 * n_feat)
self.fc1 = nn.Linear(2 * n_feat * 64 * 64, 256)
self.fc2_mean = nn.Linear(256, latent_dim)
self.fc2_logvar = nn.Linear(256, latent_dim)
def forward(self, x):
        # the network predicts the mean and variance of the approximate posterior
print("input", x.shape)
init_x = self.init_conv(x)
print("init", init_x.shape)
down1 = self.down1(init_x)
print("down1", down1.shape)
# down2 = self.down2(down1)
# print("down2", down2.shape)
# down3 = self.down3(down1)
# print("down3", down3.shape)
        down1_reshape = down1.view(-1, 2 * self.n_feat * 64 * 64)  # flatten (B, 2*n_feat, 64, 64) -> (B, 2*n_feat*64*64)
print("down1_reshape.view", down1_reshape.shape)
h = F.relu(self.fc1(down1_reshape))
print("fc1", h.shape)
mean = self.fc2_mean(h) # μ
print("mean", mean.shape)
        var = self.fc2_logvar(h)  # despite the layer name, this is treated as a variance, not a log-variance
        print("var", var.shape)
        var = F.softplus(var)  # softplus keeps the variance strictly positive
        # # compute the latent variable
        # ## draw standard normal noise
        # eps = torch.randn_like(torch.exp(mu))
        # ## latent variable: z = mu + sigma * eps
        # z = mu + torch.exp(log_var / 2) * eps
return mean, var, down1
class Decoder(nn.Module):
def __init__(self, in_channels, out_channels, latent_dim, n_feat):
super(Decoder, self).__init__()
# Decoder layers
self.in_channels = in_channels
self.out_channels = out_channels
self.n_feat = n_feat
self.fc3 = nn.Linear(latent_dim, 256)
self.fc4 = nn.Linear(256, 2 * n_feat * 64 * 64)
self.up0 = nn.Sequential(
nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 4, 4),
nn.GroupNorm(8, 2 * n_feat),
nn.ReLU(),
)
self.up1 = UnetUp(4 * n_feat, 2 * n_feat)
# self.up2 = UnetUp(4 * n_feat, n_feat)
self.up3 = UnetUp(4 * n_feat, 2 * n_feat)
self.out = nn.Conv2d(2 * n_feat, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, z, down1):
print("decode-input", z.shape)
x = F.relu(self.fc3(z))
print("decode-fc3", x.shape)
x = F.relu(self.fc4(x))
print("decode-fc4", x.shape)
        x = x.view(-1, 2 * self.n_feat, 64, 64)
print("x.view", x.shape)
# up0 = self.up0(x)
# print("up0:", up0.shape)
print("x:", x.shape)
# print("down3:", down3.shape)
# up1 = self.up1(x, down3)
# print("up1:", up1.shape)
# print("down1:", down1.shape)
# up2 = self.up2(up1, down2)
# print("up2:", up2.shape)
up3 = self.up3(x, down1)
print("up3:", up3.shape)
out = torch.sigmoid(self.out(up3))
print("out:", out.shape)
return out
# VAE-Unetモデル
class VAE(nn.Module):
def __init__(self, in_channels, out_channels, latent_dim, n_feat):
super(VAE, self).__init__()
self.encoder = Encoder(in_channels, out_channels, latent_dim, n_feat)
self.decoder = Decoder(in_channels, out_channels, latent_dim, n_feat)
def forward(self, x):
        mean, var, down1 = self.encoder(x)
z = self.latent_variable(mean, var)
y = self.decoder(z, down1)
return z, y
    def latent_variable(self, mean, var):
        # randn_like creates eps on the same device and with the same dtype as mean,
        # so it does not need to be moved to the GPU separately
        eps = torch.randn_like(mean)
        z = mean + torch.sqrt(var) * eps
        return z
    def loss(self, x):
        mean, var, down1 = self.encoder(x)
        z = self.latent_variable(mean, var)
        y = self.decoder(z, down1)
        # Bernoulli negative log-likelihood (pixel-wise binary cross entropy)
        reconst_loss = -torch.mean(torch.sum(x * torch.log(y) + (1 - x) * torch.log(1 - y), dim=1))
        # KL divergence between N(mean, var) and the standard normal prior
        latent_loss = -1 / 2 * torch.mean(torch.sum(1 + torch.log(var) - mean**2 - var, dim=1))
        loss = reconst_loss + latent_loss
        return loss
    def generate_images(self, x, device):
        mean, var, down1 = self.encoder(x)
        z = self.latent_variable(mean, var)
        generated_images = self.decoder(z, down1)
        return generated_images
# quick environment check
print(torch.cuda.is_available())
print(torch.__version__)
torch.cuda.current_device()
from PIL import Image
import os
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from pathlib import Path
from multiprocessing import freeze_support
import numpy as np
class CatDataset(Dataset):
def __init__(self, path):
files = os.listdir(path)
self.file_list = [os.path.join(path,file) for file in files]
self.transform = transforms.Compose(
[
                transforms.Resize((128, 128)),  # force a square 128x128 image so the hard-coded 64x64 feature map size is valid
transforms.RandomHorizontalFlip(0.5),
transforms.ToTensor(),
                # convert [0, 1] to [-1, 1]
# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]
)
def __len__(self):
return len(self.file_list)
def __getitem__(self, i):
        img = Image.open(self.file_list[i]).convert("RGB")  # make sure every image has 3 channels
return self.transform(img)
if __name__ == '__main__':
freeze_support()
num_epochs = 200
batch_size = 32
learning_rate = 1e-5
train_path = Path("./train")
dataset_dir = train_path
test_path = Path("./test")
dataset_dir_test = test_path
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
    # instantiate the model
    latent_dim = 64
    n_feat = 64  # number of base feature channels in the UNet (example value)
    model = VAE(3, 3, latent_dim, n_feat).to(device)
loss_function = model.loss
# model = VAE(image_size, h1_dim, h2_dim, z_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
dataset = CatDataset(dataset_dir)
dataset_test = CatDataset(dataset_dir_test)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
dataloader_test = DataLoader(dataset_test, batch_size=8, shuffle=True, num_workers=4)
# train
losses = []
for epoch in range(num_epochs):
print(f"Epoch {epoch} : ")
train_loss = 0
pbar = tqdm(dataloader)
for i, x in enumerate(pbar):
print(x.shape)
# 予測
x = x.to(device)
model.train()
loss = loss_function(x)
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += loss.item()
print("train_loss", train_loss)
train_loss /= len(dataloader)
print('Epoch({}) -- loss: {:.3f}'.format(epoch+1, train_loss))
losses.append(train_loss)
# モデルの保存
torch.save(model.state_dict(), f"./vae2/vae_model_epoch{epoch+1}.pt")
        model.eval()
        with torch.no_grad():
for i, x in enumerate(dataloader_test):
# 画像の生成と保存
x = x.to(device)
generated_images = model.generate_images(x, device=device)
save_image(generated_images, f"./vae_images2/generated_images_epoch{epoch+1}.png", nrow=4)
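One point I am unsure about is whether the hand-written reconstruction term is numerically safe: torch.log(y) becomes -inf as soon as the sigmoid output reaches exactly 0 or 1. For comparison, this is a sketch of a variant using the built-in binary cross entropy (it assumes x is already scaled to [0, 1] by ToTensor, and uses the same mean/var parameterization as the code above; vae_loss is just an illustrative name):

import torch
import torch.nn.functional as F

def vae_loss(x, y, mean, var):
    # reconstruction term: pixel-wise BCE, summed per sample and averaged over the batch
    reconst = F.binary_cross_entropy(y, x, reduction="sum") / x.size(0)
    # KL divergence between N(mean, var) and the standard normal prior
    kl = -0.5 * torch.mean(torch.sum(1 + torch.log(var) - mean**2 - var, dim=1))
    return reconst + kl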
What I tried myself
As indicated by the commented-out lines, I reduced the number of down/up stages in both the encoder and the decoder to make the model lighter, but judging from the loss curve this had almost no effect.
I also tried a smaller latent_dim and a smaller image size, but saw no effect either.
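The next thing I plan to try is pinpointing where the first NaN/Inf appears during training; a small diagnostic sketch (it reuses model, loss_function, optimizer, dataloader and device from the script above):

torch.autograd.set_detect_anomaly(True)  # report the operation that produced a NaN/Inf during backward

for x in dataloader:
    x = x.to(device)
    loss = loss_function(x)
    if not torch.isfinite(loss):
        print("non-finite loss detected:", loss.item())
        break
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # optional: limit exploding gradients
    optimizer.step()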
Thank you in advance.