0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Unet, VAE+Unet, Dncnnを用いて、ガウスノイズ画像を復元してみた

Last updated at Posted at 2023-07-14

はじめに

ピンボケ画像の復元をしたいと思い、いくつか検証を行ったので記事にしました。
ピンボケ画像は、一般的にはガウシアンフィルタ(ぼかしフィルタ)に近似できるとのことで、当初はフーリエ変換を用いた方法(ウィーナフィルタ)を検討していました。しかし、撮影環境が変わったりピンボケの拡がり方が多様な場合は、汎化性能的に深層学習の方が優位性があるかなと思い、深層学習のモデルを用いた検証を行いました。
調べてみると、Dncnnなどノイズ除去目的のモデルがあったため、dncnn含め以前作成したUnetとVAE+Unetを用いました。
加えて、モデルは復元させる綺麗な画像自体を学習するよりもノイズを学習しやすい傾向があるとのことで、(stable diffusionも考え方は似ていますよね。こちらもノイズを正規分布として仮定しているため、考え方はほぼ同じな気がします。)Unetの出力をノイズを学習させるよう、「出力 = 入力 - デコーダーの出力」としました。
また、精度検証にPSNRを用いて比較を行いました。

環境

pytorch==1.12.1+cu113
gpu:NVIDIA GeForce GTX 1660
anaconda
※パッケージは多々用いていますが、ここでの記載は省いています。
詳細はソースコード内を見て頂けると幸いです。

データセットについて

前回同様、学習データは下記を使用しました。
https://paperswithcode.com/dataset/afhq

学習時の設定

画像サイズ:(3, 128, 128)
batch_size:32
learning_rate:1e-3
その他:各パターンのコードを載せるので、そちらを確認頂きたいです。

Dncnnについて

下記参考にさせて頂きました。
https://qiita.com/jw-automation/items/f942ea0c6a02e8e50fa2#:~:text=DnCNN%EF%BC%88Denoising%20Convolutional%20Neural%20Network%EF%BC%89,-DnCNN%E3%81%AF%E3%80%8CDenosing&text=BatchNormalization%E3%81%A8%E3%81%AF%E3%80%81%E3%83%8B%E3%83%A5%E3%83%BC%E3%83%A9%E3%83%AB%E3%83%8D%E3%83%83%E3%83%88,%E5%90%91%E4%B8%8A%E3%81%AB%E5%AF%84%E4%B8%8E%E3%81%97%E3%81%BE%E3%81%99%E3%80%82

構造を見ると畳み込みとバッチ正規化を繰り返す単純なモデルではありますが、出力の形状が一番のポイントだと思います。
こちら(https://techblog.leapmind.io/blog/20220104-rimamura-survey_on_nr_model/) にも書かれていますが、綺麗な画像を復元するように学習するよりもノイズを学習させることで精度を向上させているようです。確かに、直感的には画像自体よりもノイズを学習する方が複雑度が低い気はしますね。(今回は、このノイズを学習させるという考え方をUnetのモデルにも適応させてみました。)
単純に、ノイズ画像を入力として最後の層の出力をノイズと仮定し、入力とノイズの差分を最終的なモデルの出力とします。この出力と綺麗な画像を損失関数に与えることで、出力を綺麗な画像に寄せるように学習する、つまり最終層の出力はノイズを学習するようになります。

PSNRについて

下記参考にさせて頂きました。
https://qiita.com/Daiki_P/items/94662fd340aa0381b323

類似度の指標であり、MSE以外は定数になります。
そのため、MSEが小さい、つまり2つの画像の差が小さい場合はPSNRが大きくなります。
逆に、MSEが大きい、つまり2つの画像の差が大きい場合はPSNRが小さくなります。
よって、今回はPSNRが最も大きくなるモデルを確認すればいいことになります。
(勿論、オリジナルと復元画像も確認します。)

ソースコード:共通部分

下記、パッケージです。

import torch
import torch.nn as nn
import torch.nn.functional as F
 
# PyTorch画像用
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image

from PIL import Image
from PIL import Image, ImageFilter
import os
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from multiprocessing import Pool, freeze_support, RLock
import numpy as np
import matplotlib.pyplot as plt

print(torch.cuda.is_available())
print(torch.__version__ )

torch.cuda.current_device()

下記、データローダーです。
オリジナル画像と、ガウシアンフィルタを掛けた画像を出力するようにしています。

class CatDataset(Dataset):
    def __init__(self, path):
        files = os.listdir(path)
        self.file_list = [os.path.join(path,file) for file in files]
        self.transform = transforms.Compose(
        [
        transforms.Resize(128),
        transforms.ToTensor(), 
        # 0~1を-1~1に変換
        # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
    )
    def __len__(self):
        return len(self.file_list)
    def __getitem__(self, i):
        img_org = Image.open(self.file_list[i])
        if self.transform:
            original_image = self.transform(img_org)
        
        # ぼかした画像の作成
        blurred_image = img_org.filter(ImageFilter.GaussianBlur(radius=5))
        if self.transform:
            blurred_image = self.transform(blurred_image)
        
        return original_image, blurred_image
        # return self.transform(img)

また、psnrの計算部です。
前回と同じものを用いています。

def calculate_psnr(original, reconstructed):
    original = original.to(device)
    reconstructed = reconstructed.to(device)
    mse = F.mse_loss(original, reconstructed)  # 平均二乗誤差の計算
    mse = mse.to(device)
    psnr = 10 * torch.log10(1 / mse)  # PSNRの計算
    return psnr

Unetのモデルです。
出力変更版はコメントにあるように、forward部のみ変えています。

class Unet(nn.Module):
    def __init__(self,cn=3):
        super(Unet,self).__init__()

        self.copu1 = nn.Sequential(
            nn.Conv2d(cn,48,3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(48,48,3,padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )

        for i in range(2,6):
            self.add_module('copu%d'%i,
                nn.Sequential(
                    nn.Conv2d(48,48,3,stride=1,padding=1),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(2)
                )
            )

        self.coasa1 = nn.Sequential(
            nn.Conv2d(48,48,3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(48,48,3,stride=2,padding=1,output_padding=1)
        )

        self.coasa2 = nn.Sequential(
            nn.Conv2d(96,96,3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(96,96,3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(96,96,3,stride=2,padding=1,output_padding=1)
        )

        for i in range(3,6):
            self.add_module('coasa%d'%i,
                nn.Sequential(
                    nn.Conv2d(144,96,3,stride=1,padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(96,96,3,stride=1,padding=1),
                    nn.ReLU(inplace=True),
                    nn.ConvTranspose2d(96,96,3,stride=2,padding=1,output_padding=1)
                )
            )

        self.coli = nn.Sequential(
            nn.Conv2d(96+cn,64,3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64,32,3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32,cn,3,stride=1,padding=1),
            nn.LeakyReLU(0.1)
        )

        for l in self.modules(): # 重みの初期値
            if(type(l) in (nn.ConvTranspose2d,nn.Conv2d)):
                nn.init.kaiming_normal_(l.weight.data)
                l.bias.data.zero_()

    def forward(self,x):
        x1 = self.copu1(x)
        x2 = self.copu2(x1)
        x3 = self.copu3(x2)
        x4 = self.copu4(x3)
        x5 = self.copu5(x4)

        z = self.coasa1(x5)
        z = self.coasa2(torch.cat((z,x4),1))
        z = self.coasa3(torch.cat((z,x3),1))
        z = self.coasa4(torch.cat((z,x2),1))
        z = self.coasa5(torch.cat((z,x1),1))

        return self.coli(torch.cat((z,x),1))

    # Unetの出力変更版
    # def forward(self,x):
    #    x1 = self.copu1(x)
    #    x2 = self.copu2(x1)
    #    x3 = self.copu3(x2)
    #    x4 = self.copu4(x3)
    #    x5 = self.copu5(x4)

    #    z = self.coasa1(x5)
    #    z = self.coasa2(torch.cat((z,x4),1))
    #    z = self.coasa3(torch.cat((z,x3),1))
    #    z = self.coasa4(torch.cat((z,x2),1))
    #    z = self.coasa5(torch.cat((z,x1),1))
    #    z_noize = self.coli(torch.cat((z,x),1))
    #    out = x - z_noize
    #    return out, z_noize

VAE+Unetです。
こちらも、前回使用したものと同じです。

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.down1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.down2 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.down3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.down4 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.down5 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.down6 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

    def forward(self, x):
        #print("x", x.shape)
        down1 = self.down1(x)
        #print("down1", down1.shape)
        down2 = self.down2(down1)
        #print("down2", down2.shape)
        down3 = self.down3(down2)
        #print("down3", down3.shape)
        down4 = self.down4(down3)
        #print("down4", down4.shape)
        down5 = self.down5(down4)
        #print("down5", down5.shape)
        down6 = self.down6(down5)
        #print("down6", down6.shape)
        return down6, down5, down4, down3, down2, down1


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.fc2 = nn.Linear(latent_dim, 128)
        self.fc3 = nn.Linear(128, 128 * 2 * 2)
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2),
            #nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        )
        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            #nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.up3 = nn.Sequential(
            nn.ConvTranspose2d(192, 128, kernel_size=2, stride=2),
            #nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.up4 = nn.Sequential(
            nn.ConvTranspose2d(192, 128, kernel_size=2, stride=2),
            #nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.up5 = nn.Sequential(
            nn.ConvTranspose2d(96, 64, kernel_size=2, stride=2),
            #nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.up6 = nn.Sequential(
            nn.ConvTranspose2d(96, 3, kernel_size=2, stride=2),
            #nn.BatchNorm2d(3),
            nn.ReLU(inplace=True),
            nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1),
            #nn.BatchNorm2d(3),
            nn.ReLU(inplace=True),
            nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1),
            # nn.BatchNorm2d(3),
            nn.LeakyReLU(0.1)
        )

    def forward(self, x, down1, down2, down3, down4, down5, down6):
        #print("x", x.shape)
        fc2 = F.relu(self.fc2(x))
        #print("fc2", fc2.shape)
        fc3 = F.relu(self.fc3(fc2))
        fc3 = fc3.view(-1, 128, 2, 2)
        #print("fc3", fc3.shape)
        #print("down6", down6.shape)
        # up1 = self.up1(torch.cat([fc3, down3], dim=1))
        up1 = self.up1(fc3)
        #print("up1", up1.shape)
        #print("down5", down5.shape)
        up2 = self.up2(torch.cat([up1, down5], dim=1))
        #print("up2", up2.shape)
        #print("down4", down4.shape)
        up3 = self.up3(torch.cat([up2, down4], dim=1))
        #print("up3", up3.shape)
        #print("down3", down3.shape)
        up4 = self.up4(torch.cat([up3, down3], dim=1))
        #print("up4", up4.shape)
        #print("down2", down2.shape)
        up5 = self.up5(torch.cat([up4, down2], dim=1))
        #print("up5", up5.shape)
        #print("down1", down1.shape)
        up6 = torch.sigmoid(self.up6(torch.cat([up5, down1], dim=1)))
        #print("up6", up6.shape)
        return up6


# VAE-Unetモデル
class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.latent_dim = latent_dim
        self.fc1 = nn.Linear(128 * 2 * 2, 128)
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_var = nn.Linear(128, latent_dim)

    def encode(self, x):
        down6, down5, down4, down3, down2, down1 = self.encoder(x)
        down6_ = down6.view(down6.size(0), -1)
        #print("down6_", down6_.shape)
        fc1 = F.relu(self.fc1(down6_))
        #print("fc1", fc1.shape)
        mu = self.fc_mu(fc1)
        var = self.fc_var(fc1)
        var = F.softplus(var)
        return mu, var, down6, down5, down4, down3, down2, down1

    def reparameterize(self, mu, var):
        eps = torch.randn(mu.size())
        # モデル定義時にgpuに渡しているが、何故かここでエラーが生じるのでepsをgpuに渡している
        eps = eps.to(device)
        z = mu + torch.sqrt(var)*eps
        return z

    def decode(self, z, down1, down2, down3, down4, down5, down6):
        up3 = self.decoder(z, down1, down2, down3, down4, down5, down6)
        return up3

    def forward(self, x):
        mu, var, down6, down5, down4, down3, down2, down1 = self.encode(x)
        z = self.reparameterize(mu, var)
        recon_x = self.decode(z, down1, down2, down3, down4, down5, down6)
        return recon_x, mu, var

    def loss(self, x_org, x_gauss):
        mu, var, down6, down5, down4, down3, down2, down1 = self.encode(x_gauss)
        z = self.reparameterize(mu, var)
        y = self.decode(z, down1, down2, down3, down4, down5, down6)
        # reconst_loss = torch.sum(x * torch.log(y + np.spacing(1)) + (1 - x) * torch.log(1 - y + np.spacing(1)))
        reconst_loss = nn.MSELoss()(y, x_org)
        # reconst_loss = F.binary_cross_entropy(y, x, reduction='sum')
        # reconst_loss = nn.CrossEntropyLoss()(y, x)
        latent_loss = - 1/2 * torch.sum(1 + var - torch.exp(var) - mu**2)
        # reconst_loss = -torch.mean(torch.sum(x*torch.log(y) + (1 - x)* torch.log(1 - y), dim=1))
        # latent_loss = - 1/2 * torch.mean(torch.sum(1 + torch.log(var) - mu**2 - var, dim=1))
        #print(reconst_loss, latent_loss)
        loss = reconst_loss + latent_loss

        return loss

    def predict(self, x):
        mu, var, down6, down5, down4, down3, down2, down1 = self.encode(x)
        z = self.reparameterize(mu, var)
        y = self.decode(z, down1, down2, down3, down4, down5, down6)
        # y = (y[:, :, :, :] + 1) / 2
        # print("min", torch.min(y))
        # print("max", torch.max(y))
        return y

Dncnnモデルです。
こちらを使用させて頂きました。
https://github.com/SaoYan/DnCNN-PyTorch/blob/master/models.py

※出力箇所を入力とデコーダー出力との差分にすべきでしたが忘れていました。。。

class DnCNN(nn.Module):
    def __init__(self, channels, num_of_layers=17):
        super(DnCNN, self).__init__()
        kernel_size = 3
        padding = 1
        features = 64
        layers = []
        layers.append(nn.Conv2d(in_channels=channels, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
        layers.append(nn.ReLU(inplace=True))
        for _ in range(num_of_layers-2):
            layers.append(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
            layers.append(nn.BatchNorm2d(features))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels=features, out_channels=channels, kernel_size=kernel_size, padding=padding, bias=False))
        self.dncnn = nn.Sequential(*layers)
    def forward(self, x):
        out = self.dncnn(x)
        return out

結果

テスト画像のオリジナルとガウスフィルタを掛けた画像です。
オリジナル
epoch_100_org.png

ガウスフィルタ
epoch_100_gauss.png

Unetの復元画像
epoch_100_output.png

Lossの推移
unet_result.png

PSNRの推移
unet_Transition_result.png

VAE+Unetの復元画像
epoch_100_org.png

Lossの推移
vaeunet_result.png

PSNRの推移
vaeunet_Transition_result.png

Unet(input - デコーダ出力)の復元画像
epoch_100_output.png

Lossの推移/PSNRの推移
一番いい結果だったのですが、推移のデータを上書きしてしまいました。。。
PSNRは39近くまで上がりました。

Dncnnの復元画像
epoch_100_output.png

Lossの推移
dncnn_result.png

PSNRの推移
dncnn_Transition_result.png

20230805追記

今度は、ガウスノイズをランダムに生成させてみました。
ガウスノイズをピンボケと仮定する場合、ノイズが一定になる場合はほぼないと考えられます。
そこで、ガウスノイズをランダムで生成しても上手く復元できるかを確認しました。
モデルは、unet+vqvaeを使用しました。出力は、input - デコーダ出力でノイズを学習させる方法を取りました。
(vqvaeについては、どこかで書ければと思います。下記、参考にさせて頂いたサイトです。https://data-analytics.fun/2022/01/27/pytorch-vq-vae/ )
vqvaeはvaeと構造はほぼ同じで、エンコーダーで抽出した特徴を異なる空間に変換し、デコーダーでその変換後の特徴を復元する流れになっています。
vaeとの違いは、vaeは正規分布を潜在変数としていましたが、vqvaeは一様分布のようになっています。
エンコーダーで抽出した特徴と近い(ユークリッド距離の近さ)潜在変数を選択します。
※ユークリッド距離とコサイン類似度についてhttps://enjoyworks.jp/tech-blog/2242
トランスフォーマーのQKVと考え方は近いと思います。
では、結果を記載していきます。
オリジナル画像
epoch_100_org.png
ランダムにガウスフィルタを掛けた画像
epoch_100_gauss.png
復元画像
epoch_100_output.png
Lossの推移
unet_vqvae_result.png
PSNRの推移
unet_vqvae_Transition_result.png

まとめ

結果的には、VQVAE + Unet(input - デコーダ出力)形式が最もオリジナル画像に近くなりました。
PSNRも45近くと、他の結果と比べてもいい結果になりました。
input - デコーダ出力形式の方が復元結果が良かったことから、冒頭でも述べたようにモデルはノイズを学習するように設計した方が、ノイズ除去に関しては良い精度になる可能性が高そうですね。
加えて、VQVAE形式の方が通常のUnetよりもいい結果になりました。QKV形式を用いるトランスフォーマーが良いといわれる一因なのでしょうか?
Dncnnもinput - デコーダ出力形式を試すべき(というより、正しい形式がこちらですかね?)ですが、Unetに比べて学習に時間がかかることから、今回は考察までに留めておきます。

終わりに

ピンボケ画像の復元を目的に本記事のような検証を行いました。
ノイズ除去モデルをエンコーダーデコーダー形式で行う場合は、出力形式を”input - デコーダ出力”にした方が良い結果でそうだと分かりました。また、PSNRも画像復元の指標としては中々使えるなと感じました。
VQVAEの記事を書いてからVQVAE + Unetの結果を載せるべきですが、先に載せてしまいました。
transformer関連も色々試しているので、順番は適当で書いていくと思います。
では、また次の機会に。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?