LoginSignup
1
0

Unet+VAE

Last updated at Posted at 2023-06-30

はじめに

今回は、Unetのエンコーダー、デコーダー構造と、VAEの潜在変数への変換を組み合わせたモデルで学習させてみました。
Stable Diffusionでは、画像の次元削減をエンコーダーで行ってからノイズ付与と画像生成を行い、その生成した画像をデコーダーで復元するという手法を取っています。そのため、画像生成モデルを挟むように存在するエンコーダーデコーダーを自分で用意したいと思ったことがきっかけで、本記事のような事を行いました。
画像は(3, 128, 128)カラー画像を用いて学習を行いました。
一応、色も含めてそれなりに復元はできていましたが、復元する画像が全体的に白っぽくなったりしました。
※最初はbatchnormを使用しなかったためか、損失が発散したりしました。
加えて、学習率も損失の収束に影響が見られたため、今回は多々気づきがあって良かったです。
追記にありますが、パラメータ数の削減とbatchnorm2Dを省くことで綺麗に復元ができました。

環境

pytorch==1.12.1+cu113
(+cuと付いていないとGPUが使用できないようなのでお気を付けください。)
gpu:NVIDIA GeForce GTX 1660

データセットについて

前回同様、学習データは下記を使用しました。
https://paperswithcode.com/dataset/afhq
この内、猫の画像のみを用いています。

学習時の設定

・画像は(3, 128, 128)
・データ拡張は左右反転のみ
・batch_size:32
・learning_rate:1e-3(学習に時間がかかるため、試行錯誤ができていないです。。。)
・エンコーダで64次元に圧縮(latent_dim = 64)
・損失関数は、KL_lossとMSELoss
(損失関数はとても重要ですよね。今回は、損失関数を変えると損失の収束に大きく影響が見られましたが、学習をそこまで回せないのであまり考察はできておりません。。。)

コード

まずは使用したパッケージと、gpuの動作確認のコードです。
torch.cuda.is_available()がTrueにならない場合は、importしているtorchに「+cu」が付いていない可能性があるので、確認してみて下さい。
※因みに私は、gpuを認識させるのに嵌りました笑

import torch
import torch.nn as nn
import torch.nn.functional as F
 
# PyTorch画像用
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image

from PIL import Image
import os
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from multiprocessing import Pool, freeze_support, RLock
import numpy as np

# gpuが適用されているかの確認
print(torch.cuda.is_available())
print(torch.__version__ )

torch.cuda.current_device()

次はEncoder/Decoderクラスです。

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.down1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.down2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.down3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        #print("x", x.shape)
        down1 = self.down1(x)
        #print("down1", down1.shape)
        down2 = self.down2(down1)
        #print("down2", down2.shape)
        down3 = self.down3(down2)
        #print("down3", down3.shape)
        return down3, down2, down1


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(256, 64, kernel_size=2, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.up3 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(3),
            nn.ReLU(inplace=True)
        )
        self.fc2 = nn.Linear(latent_dim, 256)
        self.fc3 = nn.Linear(256, 256 * 32 * 32)

    def forward(self, x, down1, down2, down3):
        #print("x", x.shape)
        fc2 = F.relu(self.fc2(x))
        #print("fc2", fc2.shape)
        fc3 = F.relu(self.fc3(fc2))
        fc3 = fc3.view(-1, 256, 32, 32)
        #print("fc3", fc3.shape)
        #print("down3", down3.shape)
        # up1 = self.up1(torch.cat([fc3, down3], dim=1))
        up1 = self.up1(fc3)
        #print("up1", up1.shape)
        #print("down2", down2.shape)
        up2 = self.up2(torch.cat([up1, down2], dim=1))
        #print("up2", up2.shape)
        #print("down1", down1.shape)
        up3 = torch.sigmoid(self.up3(torch.cat([up2, down1], dim=1)))
        #print("up3", up3.shape)
        return up3

特筆すべき箇所は、Unetの「skip connection」です。
Encoderのforwardでは、ダウンサンプリング(畳み込み)した3つの特徴量をreturnするようにし、Decoderでは、その3つの値を受け取って対応するアップサンプリング層のch方向に加えています。
こうすることで、特徴の損失を防ぎやすくしています。シンプルに考えると、次元圧縮前の特徴量を段階毎にアップサンプリング時に加えているので、元のデータを復元しやすくなるのはイメージがしやすいと思います。
※forwardでやたらprint文が多いのは、shapeを合わせるために出力を確認しながらパラメータを決めた名残です。結構shapeの調整に時間がかかったので、皆さんはどうやって層のパラメータを決めているのか気になりますね!

では、次はUnetVAEモデルです。(クラス名はVAEのままですが。。。)

# VAE-Unetモデル
class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.latent_dim = latent_dim
        self.fc1 = nn.Linear(256 * 32 * 32, 256)
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

    def encode(self, x):
        down3, down2, down1 = self.encoder(x)
        down3_ = down3.view(down3.size(0), -1)
        #print("down3_", down3_.shape)
        fc1 = F.relu(self.fc1(down3_))
        #print("fc1", fc1.shape)
        mu = self.fc_mu(fc1)
        logvar = self.fc_logvar(fc1)
        logvar = F.softplus(logvar)
        return mu, logvar, down3, down2, down1

    def reparameterize(self, mu, logvar):
        epsilon = torch.randn_like(mu)
        std = torch.exp(0.5 * logvar)
        z = mu + epsilon * std
        return z

    def decode(self, z, down1, down2, down3):
        up3 = self.decoder(z, down1, down2, down3)
        return up3

    def forward(self, x):
        mu, logvar, down3, down2, down1 = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z, down1, down2, down3)
        return recon_x, mu, logvar

    def loss(self, x):
        mu, logvar, down3, down2, down1 = self.encode(x)
        z = self.reparameterize(mu, logvar)
        y = self.decode(z, down1, down2, down3)
        # reconst_loss = torch.sum(x * torch.log(y + np.spacing(1)) + (1 - x) * torch.log(1 - y + np.spacing(1)))
        reconst_loss = nn.MSELoss()(y, x)
        # reconst_loss = F.binary_cross_entropy(y, x, reduction='sum')
        # reconst_loss = nn.CrossEntropyLoss()(y, x)
        latent_loss = - 1/2 * torch.sum(1 + logvar - torch.exp(logvar) - mu**2)
        #print(reconst_loss, latent_loss)
        loss = reconst_loss + latent_loss

        return loss

    def predict(self, x):
        mu, logvar, down3, down2, down1 = self.encode(x)
        z = self.reparameterize(mu, logvar)
        y = self.decode(z, down1, down2, down3)
        # y = (y[:, :, :, :] + 1) / 2
        # print("min", torch.min(y))
        # print("max", torch.max(y))
        return y

こちらの特筆すべき箇所は、encode(self, x)で潜在変数算出用の平均muと分散var(logvarとしていますが分散varです。)を算出し、reparameterize(self, mu, logvar)でそれらを用いて潜在変数zを算出します。このzをデコーダーに渡して復元します。
lossも色々試しましたが、KL_lossとMSELossが現状はいい気がしました。
学習するデータセット依存かもしれませんが。

最後に、データセットローダーと学習部のコードです。

class CatDataset(Dataset):
    def __init__(self, path):
        files = os.listdir(path)
        self.file_list = [os.path.join(path,file) for file in files]
        self.transform = transforms.Compose(
        [
        transforms.Resize(128),
        transforms.RandomHorizontalFlip(0.5),
        transforms.ToTensor(), 
        # 0~1を-1~1に変換
        # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
    )
    def __len__(self):
        return len(self.file_list)
    def __getitem__(self, i):
        img = Image.open(self.file_list[i])
        return self.transform(img)

if __name__ == '__main__':
  freeze_support()
  num_epochs = 2000
  batch_size = 32
  learning_rate = 1e-3
  train_path = Path("./train")
  dataset_dir = train_path
  test_path = Path("./test")
  dataset_dir_test = test_path
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  print(device)

  # save_transform = transforms.Compose([
  #   transforms.ToPILImage(),  # テンソルをPIL画像に変換
  #   transforms.Lambda(lambda x: (x * 0.5 + 0.5)),  # -0.5から0.5の範囲を0から1に変換
  #   transforms.ToTensor()  # PIL画像をテンソルに変換
  #   ])

  # モデルのインスタンス化
  latent_dim = 64
  model = VAE(latent_dim).to(device)
  loss_function = model.loss
  # model = VAE(image_size, h1_dim, h2_dim, z_dim).to(device)
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

  dataset = CatDataset(dataset_dir)
  dataset_test = CatDataset(dataset_dir_test)
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
  dataloader_test = DataLoader(dataset_test, batch_size=16, shuffle=True, num_workers=4)
  # train
  losses = []
  for epoch in range(num_epochs):
    print(f"Epoch {epoch} : ")
    train_loss = 0
    pbar = tqdm(dataloader)
    for i, x in enumerate(pbar):
      print(x.shape)
      # 予測
      x = x.to(device)
      model.train()
      loss = loss_function(x)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      train_loss += loss.item()
      # print("train_loss", train_loss)

    train_loss /= len(dataloader)
    print('Epoch({}) -- loss: {:.3f}'.format(epoch+1, train_loss))
    losses.append(train_loss)

    # モデルの保存
    torch.save(model.state_dict(), f"./vae2/vae_model2_epoch{epoch+1}.pt")
    with torch.no_grad():
      for i, x in enumerate(dataloader_test):
        # 画像の生成と保存
        x = x.to(device)
        generated_images = model.predict(x)
        # generated_images = save_transform(output_images)
        save_image(generated_images.view(generated_images.size(0), 3, 128, 128), f"vae_images2/epoch_{epoch+1}.png", nrow=4)
        break

標準化(RGB値を-1~1の範囲に収める)も試しましたが、あまり上手くいかなかった事から正規化のみになっています。また、バッチサイズは2000としていますが全くそこまで回したことはありません。(多くて100回弱程度)

復元結果

本結果もまだ学習途中のものです。
また、今回はlossの推移のグラフが無い(追加し忘れ)ため、次回からは忘れないようにします。。。

epoch:1
epoch_1.png

最初からある程度復元できていますが、全体的に画像が白っぽいです。

epoch:30
epoch_30.png

若干鮮明になりましたが、まだまだ黒っぽい色は暗すぎてますね。

epoch:63(執筆時の最新epoch)
epoch_63.png

この時のlossは「3292.83」でした。
学習率を変えるともっと損失が小さくなったりしましたが(0.05など)、生成される画像は相変わらず白みがかっていました。
ただ、ある程度正確に復元はできています。引き続き学習は続けてみます。
※白みがかる原因に心当たりがありましたら、是非アドバイス頂けると大変助かります。。。

追記:20230701

下記記事を参考にUnetの構造を変更したところ、綺麗に画像復元ができました。
https://qiita.com/phyblas/items/2ad3d70841ca4a888ee4
モデルのパラメータ数が多かったのでしょうか?
パラメータを変更して試そうと思います。

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.down1 = nn.Sequential(
            nn.Conv2d(3,48,3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(48,48,3,padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )

        for i in range(2,6):
            self.add_module('down%d'%i,
                nn.Sequential(
                    nn.Conv2d(48,48,3,stride=1,padding=1),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(2)
                )
            )

    def forward(self, x):
        #print("x", x.shape)
        down1 = self.down1(x)
        print("down1", down1.shape)
        down2 = self.down2(down1)
        print("down2", down2.shape)
        down3 = self.down3(down2)
        print("down3", down3.shape)
        down4 = self.down4(down3)
        print("down4", down4.shape)
        down5 = self.down5(down4)
        print("down5", down5.shape)
        return down5, down4, down3, down2, down1


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.fc2 = nn.Linear(latent_dim, 256)
        self.fc3 = nn.Linear(256, 48 * 4 * 4)
        self.up1 = nn.Sequential(
            nn.Conv2d(48,48,3,stride=2,padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(48,48,3,stride=2,padding=1,output_padding=1)
        )

        self.up2 = nn.Sequential(
            nn.Conv2d(96,96,3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(96,96,3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(96,96,3,stride=2,padding=1,output_padding=1)
        )

        for i in range(3,6):
            self.add_module('up%d'%i,
                nn.Sequential(
                    nn.Conv2d(144,96,3,stride=1,padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(96,96,3,stride=1,padding=1),
                    nn.ReLU(inplace=True),
                    nn.ConvTranspose2d(96,96,3,stride=2,padding=1,output_padding=1)
                )
            )

        self.coli = nn.Sequential(
            nn.Conv2d(144,64,3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64,64,3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64,3,3,stride=2,padding=1,output_padding=1),
            # nn.Conv2d(32,3,3,stride=1,padding=1),
            nn.LeakyReLU(0.1)
        )

    def forward(self, x, down1, down2, down3, down4, down5):
        print("x", x.shape)
        fc2 = F.relu(self.fc2(x))
        print("fc2", fc2.shape)
        fc3 = F.relu(self.fc3(fc2))
        fc3 = fc3.view(-1, 48, 4, 4)
        print("fc3", fc3.shape)
        #up1 = self.up1(torch.cat([fc3, down3], dim=1))
        
        up1 = self.up1(fc3)
        print("up1", up1.shape)
        print("down5", down5.shape)
        up2 = self.up2(torch.cat([up1, down5], dim=1))
        print("up2", up2.shape)
        print("down4", down4.shape)
        up3 = self.up3(torch.cat([up2, down4], dim=1))
        print("up3", up3.shape)
        print("down3", down3.shape)
        up4 = self.up4(torch.cat([up3, down3], dim=1))
        print("up4", up4.shape)
        print("down2", down2.shape)
        up5 = self.up5(torch.cat([up4, down2], dim=1))
        print("up5", up5.shape)
        print("down1", down1.shape)
        out = torch.sigmoid(self.coli(torch.cat((up5,down1), dim=1)))
        print("out", out.shape)
        return out


# VAE-Unetモデル
class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.latent_dim = latent_dim
        self.fc1 = nn.Linear(48 * 4 * 4, 256)
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_var = nn.Linear(256, latent_dim)

    def encode(self, x):
        down5, down4, down3, down2, down1 = self.encoder(x)
        down5_ = down5.view(down5.size(0), -1)
        print("down5_", down5_.shape)
        fc1 = F.relu(self.fc1(down5_))
        print("fc1", fc1.shape)
        mu = self.fc_mu(fc1)
        var = self.fc_var(fc1)
        var = F.softplus(var)
        return mu, var, down5, down4, down3, down2, down1

    def reparameterize(self, mu, var):
        eps = torch.randn(mu.size())
        # モデル定義時にgpuに渡しているが、何故かここでエラーが生じるのでepsをgpuに渡している
        eps = eps.to(device)
        z = mu + torch.sqrt(var)*eps
        return z

    def decode(self, x, down1, down2, down3, down4, down5):
        up5 = self.decoder(x, down1, down2, down3, down4, down5)
        return up5

    def forward(self, x):
        mu, var, down5, down4, down3, down2, down1 = self.encode(x)
        z = self.reparameterize(mu, var)
        recon_x = self.decode(z, down1, down2, down3, down4, down5)
        return recon_x, mu, var

    def loss(self, x):
        mu, var, down5, down4, down3, down2, down1 = self.encode(x)
        z = self.reparameterize(mu, var)
        y = self.decode(z, down1, down2, down3, down4, down5)
        # reconst_loss = torch.sum(x * torch.log(y + np.spacing(1)) + (1 - x) * torch.log(1 - y + np.spacing(1)))
        reconst_loss = nn.MSELoss()(y, x)
        # reconst_loss = F.binary_cross_entropy(y, x, reduction='sum')
        # reconst_loss = nn.CrossEntropyLoss()(y, x)
        latent_loss = - 1/2 * torch.sum(1 + var - torch.exp(var) - mu**2)
        # reconst_loss = -torch.mean(torch.sum(x*torch.log(y) + (1 - x)* torch.log(1 - y), dim=1))
        # latent_loss = - 1/2 * torch.mean(torch.sum(1 + torch.log(var) - mu**2 - var, dim=1))
        #print(reconst_loss, latent_loss)
        loss = reconst_loss + latent_loss

        return loss

    def predict(self, x):
        mu, var, down5, down4, down3, down2, down1 = self.encode(x)
        z = self.reparameterize(mu, var)
        y = self.decode(z, down1, down2, down3, down4, down5)
        # y = (y[:, :, :, :] + 1) / 2
        # print("min", torch.min(y))
        # print("max", torch.max(y))
        return y

epoch:230
epoch_230.png
ほぼ、現画像レベルで復元ができています!

追記:20230702

モデルのパラメータ数の削減と、batchnorm2Dを外すことで復元ができました。
今回は、前述したデータセット全てを使用してみました。

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.down1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.down2 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.down3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.down4 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.down5 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.down6 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

    def forward(self, x):
        #print("x", x.shape)
        down1 = self.down1(x)
        #print("down1", down1.shape)
        down2 = self.down2(down1)
        #print("down2", down2.shape)
        down3 = self.down3(down2)
        #print("down3", down3.shape)
        down4 = self.down4(down3)
        #print("down4", down4.shape)
        down5 = self.down5(down4)
        #print("down5", down5.shape)
        down6 = self.down6(down5)
        #print("down6", down6.shape)
        return down6, down5, down4, down3, down2, down1


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.fc2 = nn.Linear(latent_dim, 128)
        self.fc3 = nn.Linear(128, 128 * 2 * 2)
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2),
            #nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        )
        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            #nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.up3 = nn.Sequential(
            nn.ConvTranspose2d(192, 128, kernel_size=2, stride=2),
            #nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.up4 = nn.Sequential(
            nn.ConvTranspose2d(192, 128, kernel_size=2, stride=2),
            #nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.up5 = nn.Sequential(
            nn.ConvTranspose2d(96, 64, kernel_size=2, stride=2),
            #nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.up6 = nn.Sequential(
            nn.ConvTranspose2d(96, 3, kernel_size=2, stride=2),
            #nn.BatchNorm2d(3),
            nn.ReLU(inplace=True),
            nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(3),
            nn.ReLU(inplace=True),
            nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1),
            # nn.BatchNorm2d(3),
            nn.LeakyReLU(0.1)
        )

    def forward(self, x, down1, down2, down3, down4, down5, down6):
        #print("x", x.shape)
        fc2 = F.relu(self.fc2(x))
        #print("fc2", fc2.shape)
        fc3 = F.relu(self.fc3(fc2))
        fc3 = fc3.view(-1, 128, 2, 2)
        #print("fc3", fc3.shape)
        #print("down6", down6.shape)
        # up1 = self.up1(torch.cat([fc3, down3], dim=1))
        up1 = self.up1(fc3)
        #print("up1", up1.shape)
        #print("down5", down5.shape)
        up2 = self.up2(torch.cat([up1, down5], dim=1))
        #print("up2", up2.shape)
        #print("down4", down4.shape)
        up3 = self.up3(torch.cat([up2, down4], dim=1))
        #print("up3", up3.shape)
        #print("down3", down3.shape)
        up4 = self.up4(torch.cat([up3, down3], dim=1))
        #print("up4", up4.shape)
        #print("down2", down2.shape)
        up5 = self.up5(torch.cat([up4, down2], dim=1))
        #print("up5", up5.shape)
        #print("down1", down1.shape)
        up6 = torch.sigmoid(self.up6(torch.cat([up5, down1], dim=1)))
        #print("up6", up6.shape)
        return up6


# VAE-Unetモデル
class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.latent_dim = latent_dim
        self.fc1 = nn.Linear(128 * 2 * 2, 128)
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_var = nn.Linear(128, latent_dim)

    def encode(self, x):
        down6, down5, down4, down3, down2, down1 = self.encoder(x)
        down6_ = down6.view(down6.size(0), -1)
        #print("down6_", down6_.shape)
        fc1 = F.relu(self.fc1(down6_))
        #print("fc1", fc1.shape)
        mu = self.fc_mu(fc1)
        var = self.fc_var(fc1)
        var = F.softplus(var)
        return mu, var, down6, down5, down4, down3, down2, down1

    def reparameterize(self, mu, var):
        eps = torch.randn(mu.size())
        # モデル定義時にgpuに渡しているが、何故かここでエラーが生じるのでepsをgpuに渡している
        eps = eps.to(device)
        z = mu + torch.sqrt(var)*eps
        return z

    def decode(self, z, down1, down2, down3, down4, down5, down6):
        up3 = self.decoder(z, down1, down2, down3, down4, down5, down6)
        return up3

    def forward(self, x):
        mu, var, down6, down5, down4, down3, down2, down1 = self.encode(x)
        z = self.reparameterize(mu, var)
        recon_x = self.decode(z, down1, down2, down3, down4, down5, down6)
        return recon_x, mu, var

    def loss(self, x):
        mu, var, down6, down5, down4, down3, down2, down1 = self.encode(x)
        z = self.reparameterize(mu, var)
        y = self.decode(z, down1, down2, down3, down4, down5, down6)
        # reconst_loss = torch.sum(x * torch.log(y + np.spacing(1)) + (1 - x) * torch.log(1 - y + np.spacing(1)))
        reconst_loss = nn.MSELoss()(y, x)
        # reconst_loss = F.binary_cross_entropy(y, x, reduction='sum')
        # reconst_loss = nn.CrossEntropyLoss()(y, x)
        latent_loss = - 1/2 * torch.sum(1 + var - torch.exp(var) - mu**2)
        # reconst_loss = -torch.mean(torch.sum(x*torch.log(y) + (1 - x)* torch.log(1 - y), dim=1))
        # latent_loss = - 1/2 * torch.mean(torch.sum(1 + torch.log(var) - mu**2 - var, dim=1))
        #print(reconst_loss, latent_loss)
        loss = reconst_loss + latent_loss

        return loss

    def predict(self, x):
        mu, var, down6, down5, down4, down3, down2, down1 = self.encode(x)
        z = self.reparameterize(mu, var)
        y = self.decode(z, down1, down2, down3, down4, down5, down6)
        # y = (y[:, :, :, :] + 1) / 2
        # print("min", torch.min(y))
        # print("max", torch.max(y))
        return y

元々のモデルのBatchNorm2dをコメントアウトしても変わらなかったため、モデルのパラメータ数が多すぎていることが原因と考えましたが、上記コードでBatchNorm2dを用いると白っぽくなったため、両者影響していたのだと思います。

epoch:150
epoch_150.png

epoch:160
epoch_160.png

今後は、学習データのクラス数や画像の統計値を見つつ、モデルの構造との関係を見ていきたいですね。
(クラス数が多く学習データの標準偏差が大きい程、モデルのパラメータ数が多いかつBatchNorm2dが効いてくると考えています。)

終わりに

今回は、VAEだと基本的には線形層のみで心もとないことと、エンコーダーデコーダー構造であれば、中間層をVAEに置き換えても動くのでは?と思い試してみました。
複数の動物の顔画像であれば復元ができました。
顔だけでなく身体全体の画像でも復元できるか、今後試そうと思います。
次はVQVAEを動かそうと思います。stable-diffusionで用いられているみたいですね。

では、またの機会に!

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0