はじめに
汚れた画像を綺麗にしたり、落書きを画像から消したり、ズームで品質が落ちた画像を復元したりするために、最近深層学習が使われる方法です。
原理は簡単、ノイズのある画像を変換器に入力して、綺麗な画像を出力させるだけです。
そのような変換器というのはよく畳み込みニューラルネットワーク(CNN)から成なされるものです。
この記事では、そのようなCNNをpytorchで実装してみて、結果を見てみます。
このようなCNNのモデルは色々ありますが、今回は3種のCNNモデルで試して結果を比べてみます。
概要
ノイズのある画像から綺麗な画像に変える変換器を作るためには、まずCNNモデルを決めることです。
そのCNNはノイズのある画像を入力として、出力は入力と同じ大きさの画像。
モデル通りCNNを作ったら、まず最初はCNNの中のパラメータはランダムだから最初の出力は出鱈目であるはず。
その結果は模範の画像と比べてMSE(mean square error)を計算して誤差逆伝播法でパラメータを更新していって、出力は綺麗になっていく。
結局、出力の画像は模範とほぼ変わらないものになります。
学習完了した後、学習に使われなかった画像を入力しても、同じようにノイズが消されるでしょう。(そうじゃない場合は過学習になっているということ)
もっと詳しくはqiitaで参考になる記事がたくさんあります。
使う画像データの準備
今回使う画像はこの前DCGANの実装を試したと同じくsafebooruから取得したものです。
詳しくはDCGANの記事で
画像をダウンロードして、256x256サイズに切り取って、このような画像ができます。
そしてこのように3つの方法で画像をいじります。
- ガウス雑音
- 落書き
- 縮小して再び拡大
どちらにせよ、ノイズを加えて画像の品質を落とすことです。そのような画像をCNNで復元してみます。
ただし、CNNの学習をする時に予めこのような画像を予め準備するのではなく、毎回綺麗なサンプルからランダムでノイズを入れて使います。こうやって毎回ランダムで違うノイズからなる画像が使われます。
使うモデル
今回使おうとしたのは4種のモデル
U-net
論文 Ronneberger & Fischer & Brox 2015
コード https://github.com/DuFanXin/U-net , https://github.com/milesial/Pytorch-UNet
RED-Net
論文 Mao & Shen & Yang 2016
コード https://github.com/ved27/RED-net
DnCNN
論文 Zhang et al. 2017
コード https://github.com/cszn/DnCNN
WIN5-RB
論文 Liu & Fang 2017
コード https://github.com/cswin/WIN
ただしRED-Netは結局大きすぎて、使おうとしてみましたが、時間がかかりすぎて効果的ではないため、結局諦めました。
各モデルの構造を絵で説明します。
画像はadobe illustratorのextendscriptで書いてみたものです。
U-net
RED-Net
DnCNN
WIN5-RB
実装
今回実装してみたコード
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.image import imsave,imread
from skimage.transform import resize
import os,time,torch
from torch import nn
mse = nn.MSELoss()
## ニューラルネットワークモデル ##
class Unet(nn.Module):
def __init__(self,cn=3):
super(Unet,self).__init__()
self.copu1 = nn.Sequential(
nn.Conv2d(cn,48,3,stride=1,padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(48,48,3,padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2)
)
for i in range(2,6):
self.add_module('copu%d'%i,
nn.Sequential(
nn.Conv2d(48,48,3,stride=1,padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2)
)
)
self.coasa1 = nn.Sequential(
nn.Conv2d(48,48,3,stride=1,padding=1),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(48,48,3,stride=2,padding=1,output_padding=1)
)
self.coasa2 = nn.Sequential(
nn.Conv2d(96,96,3,stride=1,padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(96,96,3,stride=1,padding=1),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(96,96,3,stride=2,padding=1,output_padding=1)
)
for i in range(3,6):
self.add_module('coasa%d'%i,
nn.Sequential(
nn.Conv2d(144,96,3,stride=1,padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(96,96,3,stride=1,padding=1),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(96,96,3,stride=2,padding=1,output_padding=1)
)
)
self.coli = nn.Sequential(
nn.Conv2d(96+cn,64,3,stride=1,padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(64,32,3,stride=1,padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(32,cn,3,stride=1,padding=1),
nn.LeakyReLU(0.1)
)
for l in self.modules(): # 重みの初期値
if(type(l) in (nn.ConvTranspose2d,nn.Conv2d)):
nn.init.kaiming_normal_(l.weight.data)
l.bias.data.zero_()
def forward(self,x):
x1 = self.copu1(x)
x2 = self.copu2(x1)
x3 = self.copu3(x2)
x4 = self.copu4(x3)
x5 = self.copu5(x4)
z = self.coasa1(x5)
z = self.coasa2(torch.cat((z,x4),1))
z = self.coasa3(torch.cat((z,x3),1))
z = self.coasa4(torch.cat((z,x2),1))
z = self.coasa5(torch.cat((z,x1),1))
return self.coli(torch.cat((z,x),1))
class DnCNN(nn.Sequential):
def __init__(self,cn=3):
super(DnCNN,self).__init__()
self.add_module('coba1',nn.Sequential(
nn.Conv2d(cn,64,3,stride=1,padding=1),
nn.ReLU(inplace=True),
)
)
for i in range(2,17):
self.add_module('coba%d'%i,
nn.Sequential(
nn.Conv2d(64,64,3,stride=1,padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
)
)
self.add_module('coba17',nn.Conv2d(64,cn,3,stride=1,padding=1))
for l in self.modules(): # 重みの初期値
if(type(l)==nn.Conv2d):
nn.init.kaiming_normal_(l.weight.data)
l.bias.data.zero_()
elif(type(l)==nn.BatchNorm2d):
l.weight.data.fill_(1)
l.bias.data.zero_()
class Win5RB(nn.Module):
def __init__(self,cn=3):
super(Win5RB,self).__init__()
inc = cn
for i in range(1,5):
self.add_module('coba%d'%i,
nn.Sequential(
nn.Conv2d(inc,64,7,stride=1,padding=3),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
)
)
inc = 64
self.add_module('coba5',nn.Sequential(
nn.Conv2d(64,cn,7,stride=1,padding=3),
nn.BatchNorm2d(cn)
)
)
for l in self.modules(): # 重みの初期値
if(type(l)==nn.Conv2d):
nn.init.kaiming_normal_(l.weight.data)
l.bias.data.zero_()
elif(type(l)==nn.BatchNorm2d):
l.weight.data.fill_(1)
l.bias.data.zero_()
def forward(self,x):
z = self.coba1(x)
z = self.coba2(z)
z = self.coba3(z)
z = self.coba4(z)
return self.coba5(z) + x
class Rednet(nn.Module):
def __init__(self,cn=3):
super(Rednet,self).__init__()
self.add_module('con1',nn.Sequential(
nn.Conv2d(cn,64,7,stride=1,padding=3),
nn.ReLU(inplace=True)
)
)
for i in range(2,16):
self.add_module('con%d'%i,
nn.Sequential(
nn.Conv2d(64,64,7,stride=1,padding=3),
nn.ReLU(inplace=True)
)
)
for i in range(1,15):
self.add_module('cont%d'%i,
nn.Sequential(
nn.ConvTranspose2d(64,64,7,stride=1,padding=3),
nn.ReLU(inplace=True)
)
)
self.add_module('cont15',nn.ConvTranspose2d(64,cn,7,stride=1,padding=3))
for l in self.modules(): # 重みの初期値
if(type(l) in (nn.ConvTranspose2d,nn.Conv2d)):
nn.init.kaiming_normal_(l.weight.data)
l.bias.data.zero_()
def forward(self,x):
z = self.con1(x)
x2 = self.con2(z)
z = self.con3(x2)
x4 = self.con4(z)
z = self.con5(x4)
x6 = self.con6(z)
z = self.con7(x6)
x8 = self.con8(z)
z = self.con9(x8)
x10 = self.con10(z)
z = self.con11(x10)
x12 = self.con12(z)
z = self.con13(x12)
x14 = self.con14(z)
z = self.con15(x14)
z = self.cont1(z)
z = self.cont2(x14+z)
z = self.cont3(z)
z = self.cont4(x12+z)
z = self.cont5(z)
z = self.cont6(x10+z)
z = self.cont7(z)
z = self.cont8(x8+z)
z = self.cont9(z)
z = self.cont10(x6+z)
z = self.cont11(z)
z = self.cont12(x4+z)
z = self.cont13(z)
z = self.cont14(x2+z)
z = self.cont15(z)
return z
## データローダ ##
class Gazoudalo:
def __init__(self,folder,fnoise,px=256,n_batch=4,random=False):
self.folder = folder # データの保存したフォルダ
self.fnoise = fnoise # 使うノイズ関数
self.px = px # 画像の大きさ
self.n_batch = n_batch # バッチサイズ
self.random = random # ランダムするかどうか
self.file = np.sort(glob(os.path.join(folder,'*.jpg'))) # フォルダの中の全部の画像の名前
self.len = len(self.file) # 画像の枚数
self.nkai = int(np.ceil(self.len/n_batch)) # このバッチサイズで繰り返すことができる回数
def __iter__(self):
self.i_iter = 0
if(self.random):
np.random.seed(None)
# ランダムで順番を入れ替える
self.i_rand = np.random.permutation(self.len)
else:
np.random.seed(0)
self.i_rand = np.arange(self.len)
return self
def __next__(self):
if(self.i_iter>=self.len):
raise StopIteration
x = []
y = []
for i in self.i_rand[self.i_iter:self.i_iter+self.n_batch]:
yi = imread(self.file[i])/255. # 画像読み込み
if(self.random and np.random.random()>0.5):
yi = yi[:,::-1] # ランダムで画像の左右を逆転する
xi = self.fnoise(yi,self.px) # ノイズを入れる
# 指定のサイズに変える
xi = resize(xi,(self.px,self.px),anti_aliasing=True,mode='constant')
yi = resize(yi,(self.px,self.px),anti_aliasing=True,mode='constant')
xi = torch.Tensor(xi.transpose(2,0,1))
yi = torch.Tensor(yi.transpose(2,0,1))
x.append(xi)
y.append(yi)
x = torch.stack(x)
y = torch.stack(y)
self.i_iter += self.n_batch
return x,y
## 画像変換器 ##
class Dinonet:
def __init__(self,net,cn,hozon_folder,gakushuuritsu=1e-3,gpu=1):
self.gakushuuritsu = gakushuuritsu
self.cn = cn
self.net = net(cn=cn)
self.opt = torch.optim.Adam(self.net.parameters(),lr=gakushuuritsu)
if(gpu):
# GPUを使う場合
self.dev = torch.device('cuda')
self.net.cuda()
else:
self.dev = torch.device('cpu')
self.hozon_folder = hozon_folder
# 保存のフォルダが元々なければ予め作っておく
if(not os.path.exists(hozon_folder)):
os.mkdir(hozon_folder)
# 訓練されたパラメータと結果を保存するファイル
netparam_file = os.path.join(hozon_folder,'netparam.pkl')
if(os.path.exists(netparam_file)):
# セーブしておいたデータがすでにある場合は先回から続き
s = torch.load(netparam_file)
self.net.load_state_dict(s['w'])
self.opt.load_state_dict(s['o'])
self.mounankai = s['n']
self.sonshitsu = s['l']
else:
# 最初から開始
self.mounankai = 0
self.sonshitsu = []
def gakushuu(self,dalo_kunren,dalo_kenshou,n_kurikaeshi,n_kaku=5,yokukataru=10):
print('訓練:%d枚 | 検証%d枚'%(dalo_kunren.len,dalo_kenshou.len))
t0 = time.time()
kenshou_data = []
for da_ken in dalo_kenshou:
kenshou_data.append(da_ken)
dalo_kunren.random = True
print('画像の準備に%.3f分かかった'%((time.time()-t0)/60))
print('==学習開始==')
t0 = time.time()
# 何回も繰り返して訓練する
for kaime in range(self.mounankai,self.mounankai+n_kurikaeshi):
# ミニバッチ開始
for i_batch,(x,y) in enumerate(dalo_kunren):
z = self.net(x.to(self.dev))
sonshitsu = mse(z,y.to(self.dev)) # 訓練データの損失
self.opt.zero_grad()
sonshitsu.backward()
self.opt.step()
# 検証データにテスト
if((i_batch+1)%int(np.ceil(dalo_kunren.nkai/yokukataru))==0 or i_batch==dalo_kunren.nkai-1):
self.net.eval()
sonshitsu = []
if(n_kaku):
gazou = []
n_kaita = 0
for x,y in kenshou_data:
z = self.net(x.to(self.dev))
# 検証データの損失
sonshitsu.append(mse(z,y.to(self.dev)).item())
# 検証データからできた一部の画像を書く
if(n_kaita<n_kaku):
x = x.numpy().transpose(0,2,3,1) # 入力
y = y.numpy().transpose(0,2,3,1) # 模範
z = np.clip(z.cpu().detach().numpy(),0,1).transpose(0,2,3,1) # 出力
for i,(xi,yi,zi) in enumerate(zip(x,y,z)):
# [入力、出力、模範]
gazou.append(np.vstack([xi,zi,yi]))
n_kaita += 1
if(n_kaita>=n_kaku):
break
sonshitsu = np.mean(sonshitsu)
if(n_kaku):
gazou = np.hstack(gazou)
imsave(os.path.join(self.hozon_folder,'kekka%03d.jpg'%(kaime+1)),gazou)
# 今の状態を出力する
print('%d:%d/%d ~ 損失:%.4e %.2f分過ぎた'%(kaime+1,i_batch+1,dalo_kunren.nkai,sonshitsu,(time.time()-t0)/60))
self.net.train()
# ミニバッチ一回終了
self.sonshitsu.append(sonshitsu)
# パラメータや状態を保存する
sd = dict(w=self.net.state_dict(),o=self.opt.state_dict(),n=kaime+1,l=self.sonshitsu)
torch.save(sd,os.path.join(self.hozon_folder,'netparam.pkl'))
# 損失(MSE)の変化を表すグラフを書く
plt.figure(figsize=[5,4])
plt.gca(ylabel='MSE')
ar = np.arange(1,kaime+2)
plt.plot(ar,self.sonshitsu,'#11aa99')
plt.tight_layout()
plt.savefig(os.path.join(self.hozon_folder,'graph.png'))
plt.close()
def __call__(self,x,n_batch=8):
self.net.eval()
x = torch.Tensor(x)
y = []
for i in range(0,len(x),n_batch):
y.append(self.net(x[i:i+n_batch].to(self.dev)).detach().cpu())
return torch.cat(y).numpy()
## ノイズの関数 ##
# ガウス雑音
class Gaussnoise:
def __init__(self,a,clip=True):
self.a = a # 雑音のサイズ
self.clip = clip # ノイズを追加した後、[0,1]の範囲内にするかどうか
def __call__(self,y,px):
x = y + np.random.randn(*y.shape)*self.a
if(self.clip):
x = np.clip(x,0,1)
return x
# 落書きノイズ
class Rakugaki:
def __init__(self,l0,l1,n):
self.l0 = l0 # 一番短い線の長さ / 画像のピクセル
self.l1 = l1 # 一番長い線の長さ / 画像のピクセル
self.n = n # 線の数
def __call__(self,y,px):
x = y.copy()
h = min(y.shape[0],y.shape[1])
l0 = int(h*self.l0) # 一番短い線のピクセル
l1 = int(h*self.l1) # 一番長い線のピクセル
kk = np.random.randint(0,h,[self.n,2])
cc = np.random.randint(0,256,[self.n,3])/255.
ll = np.random.randint(l0,l1+1,self.n)
for k,c,l in zip(kk,cc,ll):
p = np.zeros([l,2],dtype=int)
p1 = np.random.randint(0,4,l-1)
p[0] = k
p2 = p1%2
p3 = p1>1
p[1:,0] += (p2==0)*np.where(p3,-1,1)
p[1:,1] += (p2==1)*np.where(p3,-1,1)
p = p.cumsum(0)%h
x[p[:,0],p[:,1]] = c
return x
# 縮小からできたノイズ
class Shukushou:
def __init__(self,scale=0.1):
self.scale = scale
def __call__(self,y,px):
px = int(px*self.scale)
return resize(y,(px,px),anti_aliasing=True,mode='constant')
## 実行 ##
kunren_folder = 'kunren' # 訓練データのフォルダ
kenshou_folder = 'kenshou' # 検証データのフォルダ
hozon_folder = 'hozon' # 結果を保存するフォルダ
cn = 3 # チャネル数 (3色データ)
n_batch = 8 # バッチサイズ
px = 256 # 画像の大きさ
n_kurikaeshi = 30 # 何回繰り返すか
n_kaku = 6 # 見るために結果の画像を何枚出力する
yokukataru = 10 # 一回の訓練で何回結果を出力する
# 使うモデルを選ぶ
model = Unet
#net = DnCNN
#net = Win5RB
# ノイズを起こす関数を選ぶ
fnoise = Gaussnoise(0.3)
#fnoise = Rakugaki(0.5,5,25)
#fnoise = Shukushou(0.25)
dalo_kunren = Gazoudalo(kunren_folder,fnoise,px,n_batch) # 訓練データ
dalo_kenshou = Gazoudalo(kenshou_folder,fnoise,px,n_batch) # 検証データ
dino = Dinonet(model,cn,hozon_folder)
# 学習開始
dino.gakushuu(dalo_kunren,dalo_kenshou,n_kurikaeshi,n_kaku,yokukataru)
結果
次はあの3種のモデルで、同じく30回練習した後、検証データでテストした結果を見ます。
ガウス雑音
学習の進歩
出力できた画像。上から「入力、DnCNN、U-net、WIN4-RB、オリジナル」の順
落書き
学習の進歩
出力できた画像
縮小から復元
学習の進歩
出力できた画像
纏め
結果から見ると、どのモデルでもある程度ノイズを除去できるようです。
損失の値から見ると、ガウス雑音と縮小の場合はU-netが一番いいようですが、落書きの場合DnCNNの方が一番いいようです。
その他に、綺麗な画像無しで学習することもできるnoise2noiseという方法があります。これについてこの記事で書いています。