@megane-9mm

About a model that combines a U-Net and a VAE


What I want to solve

I made the encode/decode part a U-Net and changed the point where the encoder and decoder connect to the VAE scheme (mapping into a Gaussian latent space), then trained it.
However, the loss diverges, and it did not improve even after making the model lighter.
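
By the "VAE scheme" above I mean the usual reparameterization step between the encoder and the decoder. Schematically it looks like the following (a minimal sketch with placeholder shapes, written in the textbook log-variance form; the actual code below uses a softplus variance instead):

import torch

h = torch.randn(4, 256)                   # flattened encoder features (placeholder shape)
fc_mean = torch.nn.Linear(256, 64)        # 64-dimensional latent space
fc_logvar = torch.nn.Linear(256, 64)
mean, logvar = fc_mean(h), fc_logvar(h)
eps = torch.randn_like(mean)
z = mean + torch.exp(0.5 * logvar) * eps  # z = mu + sigma * eps
# z is then reshaped/upsampled and decoded by the U-Net decoder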

So I would be very grateful for any feedback, for example whether there is anything wrong in the code itself, or suggestions like "this part would be better done this way".
Thank you in advance.

Note: there is no requirement that the encoder/decoder use the structure in this post.
However, since I am still short on experience, I am posting this to ask people with more knowledge what kind of approach works well in a case like this.

The problem / error

The relevant code is listed below.

import torch
import torch.nn as nn
import torch.nn.functional as F
 
# torchvision for image handling
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image


class Conv3(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, is_res: bool = False
    ) -> None:
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.GroupNorm(8, out_channels),
            nn.ReLU(),
        )
        self.conv = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.GroupNorm(8, out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.GroupNorm(8, out_channels),
            nn.ReLU(),
        )

        self.is_res = is_res

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.main(x)
        if self.is_res:
            x = x + self.conv(x)
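            # 1.414 ≈ sqrt(2): scales the residual sum so its variance stays roughly constant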
            return x / 1.414
        else:
            return self.conv(x)


class UnetDown(nn.Module):
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super(UnetDown, self).__init__()
        layers = [Conv3(in_channels, out_channels), nn.MaxPool2d(2)]
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        return self.model(x)


class UnetUp(nn.Module):
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super(UnetUp, self).__init__()
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            Conv3(out_channels, out_channels),
            Conv3(out_channels, out_channels),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
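        # concatenate the decoder features with the encoder skip connection along the channel axis before upsampling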
        x = torch.cat((x, skip), 1)
        x = self.model(x)

        return x


class Encoder(nn.Module):
  def __init__(self, in_channels, out_channels, latent_dim, n_feat):
    super(Encoder, self).__init__()
    # Encoder layers
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.n_feat = n_feat

    self.init_conv = Conv3(in_channels, n_feat, is_res=True)
    self.down1 = UnetDown(n_feat, 2 * n_feat)
    # self.down2 = UnetDown(n_feat, 2 * n_feat)
    # self.down3 = UnetDown(2 * n_feat, 2 * n_feat)
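    # fc1 assumes a (2*n_feat) x 64 x 64 feature map, i.e. a 128x128 input after one 2x downsample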
    self.fc1 = nn.Linear(2 * n_feat * 64 * 64, 256)
    self.fc2_mean = nn.Linear(256, latent_dim)
    self.fc2_logvar = nn.Linear(256, latent_dim)
  
  def forward(self, x):
    # Compute the mean and variance of the posterior distribution with the network
    print("input", x.shape)
    init_x = self.init_conv(x)
    print("init", init_x.shape)
    down1 = self.down1(init_x)
    print("down1", down1.shape)
    # down2 = self.down2(down1)
    # print("down2", down2.shape)
    # down3 = self.down3(down1)
    # print("down3", down3.shape)
    down1_reshape = down1.view(-1, 2 * self.n_feat * 64 * 64)
    print("down1_reshape.view", down1_reshape.shape)
    h = F.relu(self.fc1(down1_reshape))
    print("fc1", h.shape)
    mean = self.fc2_mean(h) # μ
    print("mean", mean.shape)
    var = self.fc2_logvar(h) # s
    print("var", var.shape)
    var = F.softplus(var)
 
    # # Compute the latent variable
    # ## Sample standard normal noise
    # eps = torch.randn_like(torch.exp(mu))
    # ## Compute the latent variable: mu + sigma * eps
    # z = mu + torch.exp(log_var / 2) * eps
    return mean, var, down1


class Decoder(nn.Module):
  def __init__(self, in_channels, out_channels, latent_dim, n_feat):
    super(Decoder, self).__init__()
    # Decoder layers
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.n_feat = n_feat

    self.fc3 = nn.Linear(latent_dim, 256)
    self.fc4 = nn.Linear(256, 2 * n_feat * 64 * 64)
    self.up0 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 4, 4),
            nn.GroupNorm(8, 2 * n_feat),
            nn.ReLU(),
        )

    self.up1 = UnetUp(4 * n_feat, 2 * n_feat)
    # self.up2 = UnetUp(4 * n_feat, n_feat)
    self.up3 = UnetUp(4 * n_feat, 2 * n_feat)
    self.out = nn.Conv2d(2 * n_feat, out_channels, kernel_size=1, stride=1, padding=0)
     
  def forward(self, z, down1):
      print("decode-input", z.shape)
      x = F.relu(self.fc3(z))
      print("decode-fc3", x.shape)
      x = F.relu(self.fc4(x))
      print("decode-fc4", x.shape)
      x = x.view(-1, 2 * self.n_feat, 64, 64)
      print("x.view", x.shape)
      # up0 = self.up0(x)
      # print("up0:", up0.shape)
      print("x:", x.shape)
      # print("down3:", down3.shape)
      # up1 = self.up1(x, down3) 
      # print("up1:", up1.shape)
      # print("down1:", down1.shape)
      # up2 = self.up2(up1, down2)
      # print("up2:", up2.shape)
      up3 = self.up3(x, down1)
      print("up3:", up3.shape)
      out = torch.sigmoid(self.out(up3))
      print("out:", out.shape)
      return out


# VAE-Unetモデル
class VAE(nn.Module):
  def __init__(self, in_channels, out_channels, latent_dim, n_feat):
    super(VAE, self).__init__()
    self.encoder = Encoder(in_channels, out_channels, latent_dim, n_feat)
    self.decoder = Decoder(in_channels, out_channels, latent_dim, n_feat)
    
  def forward(self, x):
    mean, var, down1= self.encoder(x)
    z = self.latent_variable(mean, var)
    y = self.decoder(z, down1)
    return z, y

  def latent_variable(self, mean, var):
    eps = torch.randn(mean.size())
    # The model is moved to the GPU when it is created, but for some reason an error occurs here, so eps is explicitly moved to the GPU
    eps = eps.to(device)
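    # reparameterization trick: z = mean + sqrt(var) * eps (var is the softplus output from the encoder)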
    z = mean + torch.sqrt(var)*eps
    return z

  def loss(self, x):
    mean, var, down1= self.encoder(x)
    z = self.latent_variable(mean, var)
    y = self.decoder(z, down1)
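    # reconstruction term: pixel-wise binary cross-entropy, summed over dim=1 and averaged over the rest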
    reconst_loss = -torch.mean(torch.sum(x*torch.log(y) + (1 - x)* torch.log(1 - y), dim=1))
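    # latent term: KL divergence between N(mean, var) and N(0, I), summed over the latent dimension and averaged over the batch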
    latent_loss = - 1/2 * torch.mean(torch.sum(1 + torch.log(var) - mean**2 - var, dim=1))
    loss = reconst_loss + latent_loss

    return loss

  def generate_images(self, x, device):
      mean, var, down1 = self.encoder(x)
      z = self.latent_variable(mean, var)
      generated_images = self.decoder(z, down1)
      return generated_images
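
For reference, a quick shape check like the following (a rough sketch meant to be run on its own; n_feat=8 and the CPU device are just placeholders, not the training settings) can be used to confirm that the encoder/decoder wiring lines up for 128x128 inputs:

device = torch.device("cpu")                      # placeholder; the training script below selects CUDA
check_model = VAE(3, 3, latent_dim=64, n_feat=8)  # small n_feat only for the shape check
dummy = torch.rand(2, 3, 128, 128)
z_chk, y_chk = check_model(dummy)
print(z_chk.shape, y_chk.shape)                   # expected: torch.Size([2, 64]) and torch.Size([2, 3, 128, 128])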


import torch
print(torch.cuda.is_available())
print(torch.__version__ )

torch.cuda.current_device()

from PIL import Image
import os
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from pathlib import Path
from multiprocessing import Pool, freeze_support, RLock
import numpy as np 

class CatDataset(Dataset):
    def __init__(self, path):
        files = os.listdir(path)
        self.file_list = [os.path.join(path,file) for file in files]
        self.transform = transforms.Compose(
            [
                transforms.Resize(128),
                transforms.RandomHorizontalFlip(0.5),
                transforms.ToTensor(),
                # converts 0..1 to -1..1
                # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ]
        )
    def __len__(self):
        return len(self.file_list)
    def __getitem__(self, i):
        img = Image.open(self.file_list[i])
        return self.transform(img)

if __name__ == '__main__':
  freeze_support()
  num_epochs = 200
  batch_size = 32
  learning_rate = 1e-5
  train_path = Path("./train")
  dataset_dir = train_path
  test_path = Path("./test")
  dataset_dir_test = test_path
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  print(device)

  # Instantiate the model
  latent_dim = 64
  n_feat = 16  # NOTE: n_feat is not defined anywhere in the snippet above; a value is assumed here just so the script runs
  model = VAE(3, 3, latent_dim, n_feat).to(device)
  loss_function = model.loss
  # model = VAE(image_size, h1_dim, h2_dim, z_dim).to(device)
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

  dataset = CatDataset(dataset_dir)
  dataset_test = CatDataset(dataset_dir_test)
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
  dataloader_test = DataLoader(dataset_test, batch_size=8, shuffle=True, num_workers=4)
  # train
  losses = []
  for epoch in range(num_epochs):
    print(f"Epoch {epoch} : ")
    train_loss = 0
    pbar = tqdm(dataloader)
    for i, x in enumerate(pbar):
      print(x.shape)
      # forward pass / loss computation
      x = x.to(device)
      model.train()
      loss = loss_function(x)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      train_loss += loss.item()
      print("train_loss", train_loss)

    train_loss /= len(dataloader)
    print('Epoch({}) -- loss: {:.3f}'.format(epoch+1, train_loss))
    losses.append(train_loss)

    # Save the model
    torch.save(model.state_dict(), f"./vae2/vae_model_epoch{epoch+1}.pt")
    with torch.no_grad():
      for i, x in enumerate(dataloader_test):
        # Generate and save images
        x = x.to(device)
        generated_images = model.generate_images(x, device=device)
        save_image(generated_images, f"./vae_images2/generated_images_epoch{epoch+1}.png", nrow=4)

What I tried myself

As shown in the commented-out lines, I reduced the number of down/up stages in both the encoder and the decoder to make the model lighter, but judging from how the loss evolved this had almost no effect.
I also tried making latent_dim smaller and shrinking the image size, but saw no improvement there either.
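
For context, the loss watched above is the sum of the two terms computed in VAE.loss. Splitting them out, for example with a helper like the one below (a rough sketch; loss_parts is a hypothetical function, not part of the code above), would show whether it is the reconstruction term or the KL term that diverges:

# Hypothetical helper (not in the model above): return the two loss terms separately
def loss_parts(model, x):
    mean, var, down1 = model.encoder(x)
    z = model.latent_variable(mean, var)
    y = model.decoder(z, down1)
    reconst = -torch.mean(torch.sum(x * torch.log(y) + (1 - x) * torch.log(1 - y), dim=1))
    kl = -0.5 * torch.mean(torch.sum(1 + torch.log(var) - mean**2 - var, dim=1))
    return reconst, kl

# inside the training loop:
# reconst, kl = loss_parts(model, x)
# print("reconst:", reconst.item(), "kl:", kl.item())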

Thank you in advance.
