Edited at

Neural Network Console を用いた超解像処理の実装

More than 1 year has passed since last update.


1. はじめに

 最近すごいと話題になっていたSony製のディープラーニングのフレームワーク「Neural Network Console(以下、NNC)」を触ってみました。

 ある秋の日に、私はNNCの使い方を学びたくてQiitaを眺めていましたが、画像の分別器を作成するチュートリアルまでは紹介されているものの、最終的に画像を生成するようなものは見つけられませんでした。

 また、最近購入したkindle本の「TensorFlowはじめました2」が非常に素晴らしい内容だったので、これをNNCに移植してみようとしてみました。


2. デモ

以下の動画の通りです。



3. 開発環境


  • OS -Windows10

  • CPU -Intel(R) Core(TM) i7-6700K

  • GPU -GTX1080

  • メモリ -32.0 GB

  • ディスク -SSD 256 GB


4. ネットワーク全体図

3倍に拡大してボケてしまった画像を33x33で切り抜き、21x21のクリアな画像に変換していきます。


イメージ

アセット 1.png


5. Neural Network Consoleの導入

この記事がよくまとまっていると思います。


6. NNCの問題点


1. 画像が正解データのデータセットを作成できない

 NNCには非常にありがたいことに、画像をラベルと同じ名前にしたフォルダにまとめておくと、自動にデータセットとしてcsvファイルを作成してくれる便利な機能があります。しかし、一つの画像をもう一つの画像に対応させることはできません。

(もしかしたらそういう機能があるのかもしれませんが、僕は見つけられませんでした。)


2. 試してみた

以上のようなことがあったので、inputデータとlabelデータをどっちも画像にしたcsvファイルを自作して反応するか試してみました。


csvの中身

csv.png


データセットタブ

csv.png

いけそう

認識はしてくれたので、大丈夫だと信じて進んでいきます。


7. python3.6とopenCV3.2の導入

この記事が参考になりました。


8. データセット作成


考え方


正解データ


  • 21x21のクリアな画像

000013.jpg


inputデータ


  • 3倍に拡大された33x33のボケた画像

  • paddingとして正解データよりも広い範囲を切り抜いている。

000013l.jpg


csvファイル

正解データとinputデータを作成しながら、どんどん書き込んでいきます。


プログラムの構成


  1. 元画像を33x33で切り抜くことができるサイズに切り抜く。

  2. 33x33に切り抜く(cropped_img)

  3. cropped_imgを11x11に圧縮した後、33x33に拡大したものをinputデータにする。

  4. cropped_imgをちょうど中央で21x21に切り取り、正解データにする。


コード


datasetCreator.py

# coding utf-8


import os
import cv2
import sys
import csv
import random
import shutil

# モード選択
modeD = False
modeC = False
modeE = False

# データセット作成時の変数
o_path = 'data\original\\'
t_path = 'data\\tr_img\\'
l_path = 'data\lr_img\\'
files0 = []
files1 = []
texts0 = []
texts1 = []
num = 0
n = 0

# CSV作成時の変数
l_add = 'lr_img/'
t_add = 'tr_img/'
train = 'data\imgUp_train.csv'
test = 'data\imgUp_test.csv'
execu = 'data\imgUp_execu.csv'
rate = 70

# Executor作成時の変数
in_path = 'data\executor\input\\'
ou_path = 'data\executor\output\\'
fa_path = 'data\executor\\fake\\'
in_add = 'executor/input/'
fa_add = 'executor/fake/'
sam_add = 'data\executor\sample.jpg'

# MODE確認
sys.stdout.write("データセットを作成しますか?(y/n)⇒ ")
mode = input()
if mode == 'y':
modeD = True
elif mode == 'n':
modeD == False
else:
sys.stdout.write("不適切な文字が与えられました\n")
sys.exit()
sys.stdout.write("CSVを作成しますか?(y/n)⇒ ")
mode = input()
if mode == 'y':
modeC = True
elif mode == 'n':
modeC == False
else:
sys.stdout.write("不適切な文字が与えられました\n")
sys.exit()
sys.stdout.write("Executorデータを作成しますか?(y/n)⇒ ")
mode = input()
if mode == 'y':
modeE = True
elif mode == 'n':
modeE == False
else:
sys.stdout.write("不適切な文字が与えられました\n")
sys.exit()

for x in os.listdir(o_path):
if os.path.isfile(o_path + x):
files0.append(x)
for y in files0:
if(y[-4:] == '.jpg'):
texts0.append(y)

sys.stdout.write("Seting up project folder\n")
if modeD:
if os.path.exists(t_path):
shutil.rmtree(t_path)
sys.stdout.write("Removed tr_img\n")
if os.path.exists(l_path):
shutil.rmtree(l_path)
sys.stdout.write("Removed lr_img\n")
if modeE:
if os.path.exists(in_path):
shutil.rmtree(in_path)
sys.stdout.write("Removed input\n")
if os.path.exists(ou_path):
shutil.rmtree(ou_path)
sys.stdout.write("Removed output\n")
if os.path.exists(fa_path):
shutil.rmtree(fa_path)
sys.stdout.write("Removed fake\n")
if modeD:
os.mkdir(t_path)
os.mkdir(l_path)
if modeE:
os.mkdir(in_path)
os.mkdir(ou_path)
os.mkdir(fa_path)
sys.stdout.write("Setup complete\n\n")

if modeD:
sys.stdout.write("Creating Dataset\n")
for x in texts0:
img = cv2.imread(o_path+x, cv2.IMREAD_UNCHANGED)

if img is None:
sys.stdout.write("Failed to load image file.\n")
sys.exit(1)

if len(img.shape) == 3:
height, width, channels = img.shape[:3]
else:
height, width = img.shape[:2]
channels = 1

numX = width // 33
numY = height // 33

for i in range(numX):
for j in range(numY):
cropped_img = img[33*j:33*(j+1),33*i:33*(i+1)]
small_img = cv2.resize(cropped_img,(11,11))
tr_img = cropped_img[6:27,6:27]
lr_img = cv2.resize(small_img,(33,33))

cv2.imwrite(t_path + ('%06d'%num) + ".jpg",tr_img)
cv2.imwrite(l_path + ('%06d'%num) + ".jpg",lr_img)
num = num + 1

n = n + 1
sys.stdout.write("\r" + str(n) + "/" + str(len(texts0)))
num = 0
n = 0
sys.stdout.write("\nDataset Created!\n\n")

if modeC:
sys.stdout.write("Creating CSV\n")
for x in os.listdir(l_path):
if os.path.isfile(l_path + x):
files1.append(x)
for y in files1:
if(y[-4:] == '.jpg'):
texts1.append(y)

full = len(texts1)
rate = (full*rate)//100
li = list(range(full))
random.shuffle(li)

with open(train, "w") as csvfile:
writer = csv.writer(csvfile, delimiter=',',lineterminator="\n")
writer.writerow(["x:image","y:image"])
for i in range(rate):
writer.writerow([l_add+'%06d'%li[i]+".jpg",t_add+'%06d'%li[i]+".jpg"])

with open(test, "w") as csvfile:
writer = csv.writer(csvfile,delimiter=',',lineterminator="\n")
writer.writerow(["x:image","y:image"])
for i in range(full - rate):
writer.writerow([l_add+'%06d'%li[i+rate]+".jpg",t_add+'%06d'%li[i+rate]+".jpg"])
num = 0
n = 0
sys.stdout.write("CSV Created!\n\n")

if modeE:
sys.stdout.write("Creating Executor\n")
img = cv2.imread(sam_add, cv2.IMREAD_UNCHANGED)

if img is None:
sys.stdout.write("Failed to load image file.\n")
sys.exit(1)
if len(img.shape) == 3:
height, width, channels = img.shape[:3]
else:
height, width = img.shape[:2]
channels = 1

big_img = cv2.resize(img,(width*3,height*3))
numX = (width*3-12) // 21
numY = (height*3-12) // 21
sum_num = numX*numY

src_img = big_img[0:numY*21+12,0:numX*21+12]

for i in range(numX):
for j in range(numY):
input_img = src_img[21*j:21*j+33,21*i:21*i+33]
fake_img = input_img[6:27,6:27]

cv2.imwrite(in_path + ('%06d'%num) + ".jpg",input_img)
cv2.imwrite(fa_path + ('%06d'%num) + ".jpg",fake_img)
num = num + 1

sys.stdout.write("\r" + str(num) + "/" + str(sum_num))

with open(execu, "w") as csvfile:
writer = csv.writer(csvfile, delimiter=',',lineterminator="\n")
writer.writerow(["x:image","y:image"])
for i in range(sum_num):
writer.writerow([in_add+'%06d'%i+".jpg",fa_add+'%06d'%i+".jpg"])

sys.stdout.write("\nExecutor Created!\n\n")
sys.stdout.write("\nCompleted!!")



9. NNCによる深層学習


1. 実装

詳しくは「TensorFlowはじめました2」を購入して読んで欲しいですが、2回の畳み込みと最後の活性化関数をシグモイド関数にしてNNを組んでいます。

map.png

ダブルクリックの繰り返しだけで簡単にネットワークが構築できます。

本当に簡単で素晴らしい!


2. トレーニング

この本では莫大な回数で学習させていましたが、同様で行うと40日かかると出力されたため、今回は100回しか学習を回していません。log.png

このようなグラフが簡単に手に入るものNNCの良いところですね。


3. テスト

csv.png

いいかんじ


10. 画像の復元

バラバラになってしまった画像を元の形にくっつけていきます。


コード


reCreator.py

import os

import cv2
import sys

path = 'data\executor\output\\'
fake = 'data\executor\\fake\\'
sam = 'data\executor\sample.jpg'
res = 'data\executor\\result.jpg'
files = []
texts = []
args = sys.argv

isFa = int(args[1])
n = 0

for x in os.listdir(path):
if os.path.isfile(path + x):
files.append(x)
for y in files:
if(y[-4:] == '.png'):
texts.append(y)

img = cv2.imread(sam, cv2.IMREAD_UNCHANGED)
if img is None:
sys.stdout.write("Failed to load image file.\n")
sys.exit(1)
if len(img.shape) == 3:
height, width, channels = img.shape[:3]
else:
height, width = img.shape[:2]
channels = 1
numX = (width*3-12) // 21
numY = (height*3-12) // 21

n = numX*numY;

if isFa == 0:
sys.stdout.write("Super Resolution Mode.\n")
else:
sys.stdout.write("Fake Mode.\n")

for x in range(numX):
for y in range(numY):
num = y+x*numY;
if isFa == 0:
img = cv2.imread(path+str(num)+".png", cv2.IMREAD_UNCHANGED)
if isFa == 1:
img = cv2.imread(fake+('%06d'%num)+".jpg", cv2.IMREAD_UNCHANGED)
res = 'data\executor\\fake.jpg'
if img is None:
sys.stdout.write("Failed to load image file.")
sys.exit(1)
if y==0:
imgY = img
else:
imgY = cv2.vconcat([imgY, img])

sys.stdout.flush()
sys.stdout.write("\r" + str(num) + "/" + str(n))

if x==0:
imgX = imgY
else:
imgX = cv2.hconcat([imgX, imgY])

cv2.imwrite(res,imgX)
sys.stdout.write("Completed!!")



11. 結果と比較

結果の画像を載せようとしましたが、重すぎると言われました。


比較(左:出力結果 右:単純拡大)

出力結果とopenCVで単純に拡大したものをズームして並べてスクリーンショットで撮影してます。

ここでopenCVを使って並ばせることは思い浮かばなかったです。

DLOQkSMUMAIstlX.jpg

だいぶクリアになったんじゃないでしょうか。


12. まとめ

初めて挑戦したにしては非常に良い結果が出たのではないかと満足しています。次はもっと複雑なNN(DCGANのような)にも挑戦しようと思います。

NNCに関しては、非常に使いやすくて満足しています。

しかし一つのNNが一つのプロジェクトファイルとして保存されているため、pythonから直接起動などができないのが残念です。どうやら学習したパラメータは.h5ファイルで保存されているようなので、それを読み込むことさえできれば、超解像の処理を一つのpythonコードで実装できそうです。


おまけ

新しくtwitterをはじめました。フォローしていただけると幸いです。

制作物