26
24

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

PytorchAdvent Calendar 2019

Day 23

ノイズや落書きいっぱいの画像や、ズームでぶれた画像を復元するCNN (U-net, DnCNN, WIN5-RB)

Last updated at Posted at 2019-12-21

はじめに

汚れた画像を綺麗にしたり、落書きを画像から消したり、ズームで品質が落ちた画像を復元したりするために、最近深層学習が使われる方法です。

原理は簡単、ノイズのある画像を変換器に入力して、綺麗な画像を出力させるだけです。

そのような変換器というのはよく畳み込みニューラルネットワークCNN)から成なされるものです。

この記事では、そのようなCNNをpytorchで実装してみて、結果を見てみます。

このようなCNNのモデルは色々ありますが、今回は3種のCNNモデルで試して結果を比べてみます。

概要

ノイズのある画像から綺麗な画像に変える変換器を作るためには、まずCNNモデルを決めることです。

そのCNNはノイズのある画像を入力として、出力は入力と同じ大きさの画像。

モデル通りCNNを作ったら、まず最初はCNNの中のパラメータはランダムだから最初の出力は出鱈目であるはず。

その結果は模範の画像と比べてMSEmean square error)を計算して誤差逆伝播法でパラメータを更新していって、出力は綺麗になっていく。

結局、出力の画像は模範とほぼ変わらないものになります。

学習完了した後、学習に使われなかった画像を入力しても、同じようにノイズが消されるでしょう。(そうじゃない場合は過学習になっているということ)

もっと詳しくはqiitaで参考になる記事がたくさんあります。

使う画像データの準備

今回使う画像はこの前DCGANの実装を試したと同じくsafebooruから取得したものです。

詳しくはDCGANの記事で

画像をダウンロードして、256x256サイズに切り取って、このような画像ができます。

そしてこのように3つの方法で画像をいじります。

  • ガウス雑音

  • 落書き

  • 縮小して再び拡大

どちらにせよ、ノイズを加えて画像の品質を落とすことです。そのような画像をCNNで復元してみます。

ただし、CNNの学習をする時に予めこのような画像を予め準備するのではなく、毎回綺麗なサンプルからランダムでノイズを入れて使います。こうやって毎回ランダムで違うノイズからなる画像が使われます。

使うモデル

今回使おうとしたのは4種のモデル

U-net
論文 Ronneberger & Fischer & Brox 2015
コード https://github.com/DuFanXin/U-net , https://github.com/milesial/Pytorch-UNet

RED-Net
論文 Mao & Shen & Yang 2016
コード https://github.com/ved27/RED-net

DnCNN
論文 Zhang et al. 2017
コード https://github.com/cszn/DnCNN

WIN5-RB
論文 Liu & Fang 2017
コード https://github.com/cswin/WIN

ただしRED-Netは結局大きすぎて、使おうとしてみましたが、時間がかかりすぎて効果的ではないため、結局諦めました。

各モデルの構造を絵で説明します。
画像はadobe illustratorextendscriptで書いてみたものです。

U-net

unet.png

RED-Net

rednet.png

DnCNN

dncnn.png

WIN5-RB

win5rb.png

実装

今回実装してみたコード

from glob import glob
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.image import imsave,imread
from skimage.transform import resize
import os,time,torch
from torch import nn
mse = nn.MSELoss()

## ニューラルネットワークモデル ##

class Unet(nn.Module):
    def __init__(self,cn=3):
        super(Unet,self).__init__()

        self.copu1 = nn.Sequential(
            nn.Conv2d(cn,48,3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(48,48,3,padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )

        for i in range(2,6):
            self.add_module('copu%d'%i,
                nn.Sequential(
                    nn.Conv2d(48,48,3,stride=1,padding=1),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(2)
                )
            )

        self.coasa1 = nn.Sequential(
            nn.Conv2d(48,48,3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(48,48,3,stride=2,padding=1,output_padding=1)
        )

        self.coasa2 = nn.Sequential(
            nn.Conv2d(96,96,3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(96,96,3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(96,96,3,stride=2,padding=1,output_padding=1)
        )

        for i in range(3,6):
            self.add_module('coasa%d'%i,
                nn.Sequential(
                    nn.Conv2d(144,96,3,stride=1,padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(96,96,3,stride=1,padding=1),
                    nn.ReLU(inplace=True),
                    nn.ConvTranspose2d(96,96,3,stride=2,padding=1,output_padding=1)
                )
            )

        self.coli = nn.Sequential(
            nn.Conv2d(96+cn,64,3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64,32,3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32,cn,3,stride=1,padding=1),
            nn.LeakyReLU(0.1)
        )

        for l in self.modules(): # 重みの初期値
            if(type(l) in (nn.ConvTranspose2d,nn.Conv2d)):
                nn.init.kaiming_normal_(l.weight.data)
                l.bias.data.zero_()

    def forward(self,x):
        x1 = self.copu1(x)
        x2 = self.copu2(x1)
        x3 = self.copu3(x2)
        x4 = self.copu4(x3)
        x5 = self.copu5(x4)

        z = self.coasa1(x5)
        z = self.coasa2(torch.cat((z,x4),1))
        z = self.coasa3(torch.cat((z,x3),1))
        z = self.coasa4(torch.cat((z,x2),1))
        z = self.coasa5(torch.cat((z,x1),1))

        return self.coli(torch.cat((z,x),1))



class DnCNN(nn.Sequential):
    def __init__(self,cn=3):
        super(DnCNN,self).__init__()
        
        self.add_module('coba1',nn.Sequential(
                nn.Conv2d(cn,64,3,stride=1,padding=1),
                nn.ReLU(inplace=True),
            )
        )

        for i in range(2,17):
            self.add_module('coba%d'%i,
                nn.Sequential(
                    nn.Conv2d(64,64,3,stride=1,padding=1),
                    nn.BatchNorm2d(64),
                    nn.ReLU(inplace=True),
                )
            )
        
        self.add_module('coba17',nn.Conv2d(64,cn,3,stride=1,padding=1))
        
        for l in self.modules(): # 重みの初期値
            if(type(l)==nn.Conv2d):
                nn.init.kaiming_normal_(l.weight.data)
                l.bias.data.zero_()
            elif(type(l)==nn.BatchNorm2d):
                l.weight.data.fill_(1)
                l.bias.data.zero_()



class Win5RB(nn.Module):
    def __init__(self,cn=3):
        super(Win5RB,self).__init__()

        inc = cn
        for i in range(1,5):
            self.add_module('coba%d'%i,
                nn.Sequential(
                    nn.Conv2d(inc,64,7,stride=1,padding=3),
                    nn.BatchNorm2d(64),
                    nn.ReLU(inplace=True),
                )
            )
            inc = 64
        
        self.add_module('coba5',nn.Sequential(
                nn.Conv2d(64,cn,7,stride=1,padding=3),
                nn.BatchNorm2d(cn)
            )
        )
        
        for l in self.modules(): # 重みの初期値
            if(type(l)==nn.Conv2d):
                nn.init.kaiming_normal_(l.weight.data)
                l.bias.data.zero_()
            elif(type(l)==nn.BatchNorm2d):
                l.weight.data.fill_(1)
                l.bias.data.zero_()

    def forward(self,x):
        z = self.coba1(x)
        z = self.coba2(z)
        z = self.coba3(z)
        z = self.coba4(z)
        return self.coba5(z) + x



class Rednet(nn.Module):
    def __init__(self,cn=3):
        super(Rednet,self).__init__()

        self.add_module('con1',nn.Sequential(
                nn.Conv2d(cn,64,7,stride=1,padding=3),
                nn.ReLU(inplace=True)
            )
        )

        for i in range(2,16):
            self.add_module('con%d'%i,
                nn.Sequential(
                    nn.Conv2d(64,64,7,stride=1,padding=3),
                    nn.ReLU(inplace=True)
                )
            )

        for i in range(1,15):
            self.add_module('cont%d'%i,
                nn.Sequential(
                    nn.ConvTranspose2d(64,64,7,stride=1,padding=3),
                    nn.ReLU(inplace=True)
                )
            )

        self.add_module('cont15',nn.ConvTranspose2d(64,cn,7,stride=1,padding=3))

        for l in self.modules(): # 重みの初期値
            if(type(l) in (nn.ConvTranspose2d,nn.Conv2d)):
                nn.init.kaiming_normal_(l.weight.data)
                l.bias.data.zero_()

    def forward(self,x):
        z = self.con1(x)
        x2 = self.con2(z)
        z = self.con3(x2)
        x4 = self.con4(z)
        z = self.con5(x4)
        x6 = self.con6(z)
        z = self.con7(x6)
        x8 = self.con8(z)
        z = self.con9(x8)
        x10 = self.con10(z)
        z = self.con11(x10)
        x12 = self.con12(z)
        z = self.con13(x12)
        x14 = self.con14(z)
        z = self.con15(x14)

        z = self.cont1(z)
        z = self.cont2(x14+z)
        z = self.cont3(z)
        z = self.cont4(x12+z)
        z = self.cont5(z)
        z = self.cont6(x10+z)
        z = self.cont7(z)
        z = self.cont8(x8+z)
        z = self.cont9(z)
        z = self.cont10(x6+z)
        z = self.cont11(z)
        z = self.cont12(x4+z)
        z = self.cont13(z)
        z = self.cont14(x2+z)
        z = self.cont15(z)

        return z


## データローダ ##

class Gazoudalo:
    def __init__(self,folder,fnoise,px=256,n_batch=4,random=False):
        self.folder = folder # データの保存したフォルダ
        self.fnoise = fnoise # 使うノイズ関数
        self.px = px # 画像の大きさ
        self.n_batch = n_batch # バッチサイズ
        self.random = random # ランダムするかどうか
        self.file = np.sort(glob(os.path.join(folder,'*.jpg'))) # フォルダの中の全部の画像の名前
        self.len = len(self.file) # 画像の枚数
        self.nkai = int(np.ceil(self.len/n_batch)) # このバッチサイズで繰り返すことができる回数
    
    def __iter__(self):
        self.i_iter = 0
        if(self.random):
            np.random.seed(None)
            # ランダムで順番を入れ替える
            self.i_rand = np.random.permutation(self.len)
        else:
            np.random.seed(0)
            self.i_rand = np.arange(self.len)
        return self
    
    def __next__(self):
        if(self.i_iter>=self.len):
            raise StopIteration
        
        x = []
        y = []
        for i in self.i_rand[self.i_iter:self.i_iter+self.n_batch]:
            yi = imread(self.file[i])/255. # 画像読み込み
            if(self.random and np.random.random()>0.5):
                yi = yi[:,::-1] # ランダムで画像の左右を逆転する
            xi = self.fnoise(yi,self.px) # ノイズを入れる
            
            # 指定のサイズに変える
            xi = resize(xi,(self.px,self.px),anti_aliasing=True,mode='constant')
            yi = resize(yi,(self.px,self.px),anti_aliasing=True,mode='constant')
            xi = torch.Tensor(xi.transpose(2,0,1))
            yi = torch.Tensor(yi.transpose(2,0,1))
            x.append(xi)
            y.append(yi)
            
        x = torch.stack(x)
        y = torch.stack(y)
        self.i_iter += self.n_batch
        return x,y


## 画像変換器 ##

class Dinonet:
    def __init__(self,net,cn,hozon_folder,gakushuuritsu=1e-3,gpu=1):
        self.gakushuuritsu = gakushuuritsu
        self.cn = cn
        self.net = net(cn=cn)
        self.opt = torch.optim.Adam(self.net.parameters(),lr=gakushuuritsu)
        if(gpu):
            # GPUを使う場合
            self.dev = torch.device('cuda')
            self.net.cuda()
        else:
            self.dev = torch.device('cpu')
            
        self.hozon_folder = hozon_folder
        # 保存のフォルダが元々なければ予め作っておく
        if(not os.path.exists(hozon_folder)):
            os.mkdir(hozon_folder)
        
        # 訓練されたパラメータと結果を保存するファイル
        netparam_file = os.path.join(hozon_folder,'netparam.pkl')
        if(os.path.exists(netparam_file)):
            # セーブしておいたデータがすでにある場合は先回から続き
            s = torch.load(netparam_file)
            self.net.load_state_dict(s['w'])
            self.opt.load_state_dict(s['o'])
            self.mounankai = s['n']
            self.sonshitsu = s['l']
        else:
            # 最初から開始
            self.mounankai = 0
            self.sonshitsu = []
    
    def gakushuu(self,dalo_kunren,dalo_kenshou,n_kurikaeshi,n_kaku=5,yokukataru=10):
        print('訓練:%d枚 | 検証%d枚'%(dalo_kunren.len,dalo_kenshou.len))
        t0 = time.time()
        kenshou_data = []
        for da_ken in dalo_kenshou:
            kenshou_data.append(da_ken)
        dalo_kunren.random = True
        print('画像の準備に%.3f分かかった'%((time.time()-t0)/60))
        print('==学習開始==')
        
        t0 = time.time()
        # 何回も繰り返して訓練する
        for kaime in range(self.mounankai,self.mounankai+n_kurikaeshi):
            # ミニバッチ開始
            for i_batch,(x,y) in enumerate(dalo_kunren):
                z = self.net(x.to(self.dev))
                sonshitsu = mse(z,y.to(self.dev)) # 訓練データの損失
                self.opt.zero_grad()
                sonshitsu.backward()
                self.opt.step()
                
                # 検証データにテスト
                if((i_batch+1)%int(np.ceil(dalo_kunren.nkai/yokukataru))==0 or i_batch==dalo_kunren.nkai-1):
                    self.net.eval()
                    sonshitsu = []
                    if(n_kaku):
                        gazou = []
                        n_kaita = 0
                    
                    for x,y in kenshou_data:
                        z = self.net(x.to(self.dev))
                        # 検証データの損失
                        sonshitsu.append(mse(z,y.to(self.dev)).item())
                        # 検証データからできた一部の画像を書く
                        if(n_kaita<n_kaku):
                            x = x.numpy().transpose(0,2,3,1) # 入力
                            y = y.numpy().transpose(0,2,3,1) # 模範
                            z = np.clip(z.cpu().detach().numpy(),0,1).transpose(0,2,3,1) # 出力
                            for i,(xi,yi,zi) in enumerate(zip(x,y,z)):
                                # [入力、出力、模範]
                                gazou.append(np.vstack([xi,zi,yi]))
                                n_kaita += 1
                                if(n_kaita>=n_kaku):
                                    break
                    sonshitsu = np.mean(sonshitsu)
                    
                    if(n_kaku):
                        gazou = np.hstack(gazou)
                        imsave(os.path.join(self.hozon_folder,'kekka%03d.jpg'%(kaime+1)),gazou)
                    
                    # 今の状態を出力する
                    print('%d:%d/%d ~ 損失:%.4e %.2f分過ぎた'%(kaime+1,i_batch+1,dalo_kunren.nkai,sonshitsu,(time.time()-t0)/60))
                    self.net.train()
            
            # ミニバッチ一回終了
            self.sonshitsu.append(sonshitsu)
            # パラメータや状態を保存する
            sd = dict(w=self.net.state_dict(),o=self.opt.state_dict(),n=kaime+1,l=self.sonshitsu)
            torch.save(sd,os.path.join(self.hozon_folder,'netparam.pkl'))
            
            # 損失(MSE)の変化を表すグラフを書く
            plt.figure(figsize=[5,4])
            plt.gca(ylabel='MSE')
            ar = np.arange(1,kaime+2)
            plt.plot(ar,self.sonshitsu,'#11aa99')
            plt.tight_layout()
            plt.savefig(os.path.join(self.hozon_folder,'graph.png'))
            plt.close()

    def __call__(self,x,n_batch=8):
        self.net.eval()
        x = torch.Tensor(x)
        y = []
        for i in range(0,len(x),n_batch):
            y.append(self.net(x[i:i+n_batch].to(self.dev)).detach().cpu())
        return torch.cat(y).numpy()


## ノイズの関数 ##

# ガウス雑音
class Gaussnoise:
    def __init__(self,a,clip=True):
        self.a = a # 雑音のサイズ
        self.clip = clip # ノイズを追加した後、[0,1]の範囲内にするかどうか
    
    def __call__(self,y,px):
        x = y + np.random.randn(*y.shape)*self.a
        if(self.clip):
            x = np.clip(x,0,1)
        return x

# 落書きノイズ
class Rakugaki:
    def __init__(self,l0,l1,n):
        self.l0 = l0 # 一番短い線の長さ / 画像のピクセル
        self.l1 = l1 # 一番長い線の長さ / 画像のピクセル
        self.n = n # 線の数
    
    def __call__(self,y,px):
        x = y.copy()
        h = min(y.shape[0],y.shape[1])
        l0 = int(h*self.l0) # 一番短い線のピクセル
        l1 = int(h*self.l1) # 一番長い線のピクセル
        
        kk = np.random.randint(0,h,[self.n,2])
        cc = np.random.randint(0,256,[self.n,3])/255.
        ll = np.random.randint(l0,l1+1,self.n)
        for k,c,l in zip(kk,cc,ll):
            p = np.zeros([l,2],dtype=int)
            p1 = np.random.randint(0,4,l-1)
            p[0] = k
            p2 = p1%2
            p3 = p1>1
            p[1:,0] += (p2==0)*np.where(p3,-1,1)
            p[1:,1] += (p2==1)*np.where(p3,-1,1)
            p = p.cumsum(0)%h
            x[p[:,0],p[:,1]] = c
        return x

# 縮小からできたノイズ
class Shukushou:
    def __init__(self,scale=0.1):
        self.scale = scale
    
    def __call__(self,y,px):
        px = int(px*self.scale)
        return resize(y,(px,px),anti_aliasing=True,mode='constant')



## 実行 ##

kunren_folder = 'kunren' # 訓練データのフォルダ
kenshou_folder = 'kenshou' # 検証データのフォルダ
hozon_folder = 'hozon' # 結果を保存するフォルダ
cn = 3 # チャネル数 (3色データ)
n_batch = 8 # バッチサイズ
px = 256 # 画像の大きさ
n_kurikaeshi = 30 # 何回繰り返すか
n_kaku = 6 # 見るために結果の画像を何枚出力する
yokukataru = 10 # 一回の訓練で何回結果を出力する

# 使うモデルを選ぶ
model = Unet
#net = DnCNN
#net = Win5RB

# ノイズを起こす関数を選ぶ
fnoise = Gaussnoise(0.3)
#fnoise = Rakugaki(0.5,5,25)
#fnoise = Shukushou(0.25)

dalo_kunren = Gazoudalo(kunren_folder,fnoise,px,n_batch) # 訓練データ
dalo_kenshou = Gazoudalo(kenshou_folder,fnoise,px,n_batch) # 検証データ
dino = Dinonet(model,cn,hozon_folder)
# 学習開始
dino.gakushuu(dalo_kunren,dalo_kenshou,n_kurikaeshi,n_kaku,yokukataru)

結果

次はあの3種のモデルで、同じく30回練習した後、検証データでテストした結果を見ます。

ガウス雑音

学習の進歩

出力できた画像。上から「入力、DnCNN、U-net、WIN4-RB、オリジナル」の順

落書き

学習の進歩

出力できた画像

縮小から復元

学習の進歩

出力できた画像

纏め

結果から見ると、どのモデルでもある程度ノイズを除去できるようです。

損失の値から見ると、ガウス雑音と縮小の場合はU-netが一番いいようですが、落書きの場合DnCNNの方が一番いいようです。

その他に、綺麗な画像無しで学習することもできるnoise2noiseという方法があります。これについてこの記事で書いています。

26
24
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
26
24

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?