LoginSignup
12
8

More than 1 year has passed since last update.

超解像技術-SRCNN-実装してみた(Tensorflow 2.0) 訓練フェーズ編

Last updated at Posted at 2020-05-12

概要

超解像技術とは、低解像度画像を高解像度画像にする技術です。
今回は超解像技術の一種で、比較的実装が簡単なSRCNNを実装してみました。

今回はSRCNNの訓練フェーズ編です。
次回はSRCNNの推論フェーズ編になります。
実はSRCNNの他にSRGANも実装してみたので超解像技術シリーズとして、いずれそちらも取り上げます。

環境

-Software-
Windows 10 Home
Anaconda3 64-bit(Python3.7)
Spyder
-Library-
Tensorflow 2.1.0
opencv-python 4.1.2.30
-Hardware-
CPU: Intel core i9 9900K
GPU: NVIDIA GeForce RTX2080ti
RAM: 16GB 3200MHz

参考

サイト
SRCNN論文
【Intern CV Report】超解像の歴史探訪 -2016年編-
【スパースコーディング】スパースなデータ表現の利点
Keras: 超解像
ディープラーニングで簡単な超解像をやってみた
PyTorchと超解像に入門する
画像の超解像度化をするモデル SRCNN を pytorch で実装してみた

プログラム

Githubに上げておきます。
https://github.com/himazin331/Super-resolution-CNN
リポジトリには訓練フェーズ、推論フェーズが含まれています。

今回は、データセットにGeneral-100を使いました。
デモとして使えるようにGitHubのリポジトリにデータセットも入れてあります。

ソースコード

コードが汚いのはご了承ください...

srcnn_tr.py
import tensorflow as tf
import tensorflow.keras.layers as kl
from tensorflow.python.keras import backend as K

import cv2
import numpy as np

import matplotlib.pyplot as plt

import argparse as arg
import os
import sys

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # TFメッセージ非表示


# SRCNN
class SRCNN(tf.keras.Model):
    def __init__(self, h, w):
        super(SRCNN, self).__init__()

        self.conv1 = kl.Conv2D(64, 3, padding='same', activation='relu', input_shape=(None, h, w, 3))
        self.conv2 = kl.Conv2D(32, 3, padding='same', activation='relu')
        self.conv3 = kl.Conv2D(3, 3, padding='same', activation='relu')

    def call(self, x):
        h1 = self.conv1(x)
        h2 = self.conv2(h1)
        h3 = self.conv3(h2)

        return h3


# 学習
class trainer(object):
    def __init__(self, h, w):
        self.model = SRCNN(h, w)

        self.model.compile(optimizer=tf.keras.optimizers.Adam(),
                            loss=tf.keras.losses.MeanSquaredError(),
                            metrics=[self.psnr])

    def train(self, lr_imgs, hr_imgs, out_path, batch_size, epochs):
        # 学習
        his = self.model.fit(lr_imgs, hr_imgs, batch_size=batch_size, epochs=epochs)

        print("___Training finished\n\n")

        # パラメータ保存
        print("___Saving parameter...")
        self.model.save_weights(out_path)
        print("___Successfully completed\n\n")

        return his, self.model

    # PSNR(ピーク信号対雑音比)
    def psnr(self, h3, hr_imgs):
        return -10 * K.log(K.mean(K.flatten((h3 - hr_imgs))**2)) / np.log(10)


# データセット作成
def create_dataset(data_dir, h, w, mag):
    print("\n___Creating a dataset...")

    prc = ['/', '-', '\\', '|']
    cnt = 0

    # 画像データの個数
    print("Number of image in a directory: {}".format(len(os.listdir(data_dir))))

    lr_imgs = []
    hr_imgs = []

    for c in os.listdir(data_dir):
        d = os.path.join(data_dir, c)

        _, ext = os.path.splitext(c)
        if ext.lower() == '.db':
            continue
        elif ext.lower() != '.bmp':
            continue

        # 読込、リサイズ(高解像画像)
        img = cv2.imread(d)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (w, h))

        # 低解像度画像
        img_low = cv2.resize(img, (int(w / mag), int(h / mag)))
        img_low = cv2.resize(img_low, (w, h))

        lr_imgs.append(img_low)
        hr_imgs.append(img)

        cnt += 1

        print("\rLoading a LR-images and HR-images...{}    ({} / {})".format(prc[cnt % 4], cnt, len(os.listdir(data_dir))), end='')

    print("\rLoading a LR-images and HR-images...Done    ({} / {})".format(cnt, len(os.listdir(data_dir))), end='')

    # 正規化
    lr_imgs = tf.convert_to_tensor(lr_imgs, np.float32)
    lr_imgs /= 255
    hr_imgs = tf.convert_to_tensor(hr_imgs, np.float32)
    hr_imgs /= 255

    print("\n___Successfully completed\n")
    return lr_imgs, hr_imgs


# PSNR, 損失値グラフ出力
def graph_output(history):
    # PSNRグラフ
    plt.plot(history.history['psnr'])
    plt.title('Model PSNR')
    plt.ylabel('PSNR')
    plt.xlabel('Epoch')
    plt.legend(['Train'], loc='upper left')
    plt.show()

    # 損失値グラフ
    plt.plot(history.history['loss'])
    plt.title('Model loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train'], loc='upper left')
    plt.show()


def main():
    # コマンドラインオプション作成
    parser = arg.ArgumentParser(description='Super-resolution CNN training')
    parser.add_argument('--data_dir', '-d', type=str, default=None,
                        help='画像フォルダパスの指定(未指定ならエラー)')
    parser.add_argument('--out', '-o', type=str,
                        default=os.path.dirname(os.path.abspath(__file__)),
                        help='パラメータの保存先指定(デフォルト値=./srcnn.h5')
    parser.add_argument('--batch_size', '-b', type=int, default=32,
                        help='ミニバッチサイズの指定(デフォルト値=32)')
    parser.add_argument('--epoch', '-e', type=int, default=3000,
                        help='学習回数の指定(デフォルト値=3000)')
    parser.add_argument('--he', '-he', type=int, default=256,
                        help='リサイズの高さ指定(デフォルト値=256)')
    parser.add_argument('--wi', '-wi', type=int, default=256,
                        help='リサイズの指定(デフォルト値=256)')
    parser.add_argument('--mag', '-m', type=int, default=2,
                        help='縮小倍率の指定(デフォルト値=2)')
    args = parser.parse_args()

    # 画像フォルダパス未指定->例外
    if args.data_dir is None:
        print("\nException: Folder not specified.\n")
        sys.exit()
    # 存在しない画像フォルダ指定時->例外
    if os.path.exists(args.data_dir) is False:
        print("\nException: Folder \"{}\" is not found.\n".format(args.data_dir))
        sys.exit()
    # 幅高さ、縮小倍率いずれかに0が入力された時->例外
    if args.he == 0 or args.wi == 0 or args.mag == 0:
        print("\nException: Invalid value has been entered.\n")
        sys.exit()

    # 出力フォルダの作成(フォルダが存在する場合は作成しない)
    os.makedirs(args.out, exist_ok=True)
    out_path = os.path.join(args.out, "srcnn.h5")

    # 設定情報出力
    print("=== Setting information ===")
    print("# Images folder: {}".format(os.path.abspath(args.data_dir)))
    print("# Output folder: {}".format(out_path))
    print("# Minibatch-size: {}".format(args.batch_size))
    print("# Epoch: {}".format(args.epoch))
    print("")
    print("# Height: {}".format(args.he))
    print("# Width: {}".format(args.wi))
    print("# Magnification: {}".format(args.mag))
    print("===========================\n")

    # データセット作成
    lr_imgs, hr_imgs = create_dataset(args.data_dir, args.he, args.wi, args.mag)

    # 学習開始
    print("___Start training...")
    Trainer = trainer(args.he, args.wi)
    his, model = Trainer.train(lr_imgs, hr_imgs, out_path=out_path, batch_size=args.batch_size, epochs=args.epoch)

    # PSNR, 損失値グラフ出力、保存
    graph_output(his)


if __name__ == '__main__':
    main()

実行結果

Epoch数を3000、ミニバッチサイズを32としました。

下のグラフはPSNR(ピーク信号対雑音比)の記録です。詳細は後述します。
PSNR 30dbが天井ですね。。。

image.png
下のグラフは損失値の記録です。
image.png

なお、これらのグラフは保存されません。

コマンド
python srcnn_tr.py -d <フォルダ> -e <学習回数> -b <バッチサイズ>
                   (-o <保存先> -he <高さ> -wi <幅> -m <縮小倍率(整数)>)

説明

コードの説明をしていきます。

ネットワークモデル

SRCNNクラス
# SRCNN
class SRCNN(tf.keras.Model):
    def __init__(self, h, w):
        super(SRCNN, self).__init__()

        self.conv1 = kl.Conv2D(64, 3, padding='same', activation='relu', input_shape=(None, h, w, 3))
        self.conv2 = kl.Conv2D(32, 3, padding='same', activation='relu')
        self.conv3 = kl.Conv2D(3, 3, padding='same', activation='relu')

    def call(self, x):
        h1 = self.conv1(x)
        h2 = self.conv2(h1)
        h3 = self.conv3(h2)

        return h3

よくあるCNNと違うところは、出力チャンネルが段々大きくなっていくのが一般的なのに対し、
SRCNNの場合は、出力チャンネルを段々小さくしていくという点と、全結合層がない点ですね。

畳み込み層は3層であるのが一般的です。

1層目は、パッチ抽出と低解像度空間におけるスパース表現を行います。
2層目は、1層目で獲得した表現の高解像度空間に対する非線形写像を行います。
3層目は、高解像度画像の再構成を行います。
(【Intern CV Report】超解像の歴史探訪 -2016年編-より)

スパース表現(スパースコーディング)とはデータを表現するための辞書を用意し、その要素のできるだけ少ない組み合わせでデータを表現することを言うそうです。(【スパースコーディング】スパースなデータ表現の利点より)

スパース表現について、もう少し簡単に言うと、入力画像に対して少ない数の特徴マップを組み合わせて、どれだけリアルに近づけられる(近似精度)かというもの。
多くの特徴マップを組み合わせたほうが近似精度が向上する傾向にあるが、スパース表現ではあえてこれをせず、少ない要素を用いることで意味のある表現を取り出すことができます。
すなわち、データを表現するにはどの要素がどの程度有用なのかをはっきりさせるということだそうです。


データセット作成

必要とするデータは高解像度画像のみで大丈夫です。
低解像度画像は高解像度画像から作成します。

create_dataset関数
# データセット作成
def create_dataset(data_dir, h, w, mag):
    print("\n___Creating a dataset...")

    prc = ['/', '-', '\\', '|']
    cnt = 0

    # 画像データの個数
    print("Number of image in a directory: {}".format(len(os.listdir(data_dir))))

    lr_imgs = []
    hr_imgs = []

    for c in os.listdir(data_dir):
        d = os.path.join(data_dir, c)

        _, ext = os.path.splitext(c)
        if ext.lower() == '.db':
            continue
        elif ext.lower() != '.bmp':
            continue

        # 読込、リサイズ(高解像画像)
        img = cv2.imread(d)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (w, h))

        # 低解像度画像
        img_low = cv2.resize(img, (int(w / mag), int(h / mag)))
        img_low = cv2.resize(img_low, (w, h))

        lr_imgs.append(img_low)
        hr_imgs.append(img)

        cnt += 1

        print("\rLoading a LR-images and HR-images...{}    ({} / {})".format(prc[cnt % 4], cnt, len(os.listdir(data_dir))), end='')

    print("\rLoading a LR-images and HR-images...Done    ({} / {})".format(cnt, len(os.listdir(data_dir))), end='')

    # 正規化
    lr_imgs = tf.convert_to_tensor(lr_imgs, np.float32)
    lr_imgs /= 255
    hr_imgs = tf.convert_to_tensor(hr_imgs, np.float32)
    hr_imgs /= 255

    print("\n___Successfully completed\n")
    return lr_imgs, hr_imgs

まず、画像を読み込みます。OpenCVで読み込んだ場合、画素の並びがBGRとなるため、
cv2.cvtColor(img, cv2.COLOR_BGR2RGB)でRGBに変換します。その後、指定したサイズにリサイズを行います。
これで高解像度画像の準備はひとまずOKです。

次に、低解像度画像を作成します。指定した縮小倍率で割った幅・高さに縮小します。
その後、縮小する前のサイズにリサイズし直せば作成完了です。

        # 読込、リサイズ(高解像画像)
        img = cv2.imread(d)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (w, h))

        # 低解像度画像
        img_low = cv2.resize(img, (int(w / mag), int(h / mag)))
        img_low = cv2.resize(img_low, (w, h))

図を使って表すとこんな感じです。
image.png

蛇足ですが、OpenCVのcv2.resize()の補間アルゴリズムはデフォルトではBilinearが使われます。


学習

trainerクラスで機械学習を行う前のセットアップや学習を行います。
Tensorflowで"trainer"ってあまり言わないと思うんですが、私はChainerから始めた人間なので...

まずは、インスタンスメソッドを説明します。

trainerクラス(インスタンスメソッド)
# 学習
class trainer(object):
    def __init__(self, h, w):
        self.model = SRCNN(h, w)

        self.model.compile(optimizer=tf.keras.optimizers.Adam(),
                            loss=tf.keras.losses.MeanSquaredError(),
                            metrics=[self.psnr])

インスタンス生成時にインスタンスメソッド__init__をコールして、ネットワークモデルの構築と最適化アルゴリズムを決定します。
self.model = SRCNN(h, w)でSRCNNクラスのインスタンスメソッドに高さ・幅の情報を渡しています。
モデルが構築できたら、最適化アルゴリズムと損失関数をモデルにセットします。
今回は最適化アルゴリズムはAdamにしました。損失関数は平均二乗誤差を使用してください。

低解像度画像と高解像度画像の平均二乗誤差を求め、その値を小さくしていくことで、低解像度画像が徐々に高解像度画像に近づいていくという論理です。

metrics=[self.psnr]では評価関数にPSNRをセットしています。詳細は後述します。


続いて、trainメソッドの説明です。

trainerクラス(trainメソッド)
    def train(self, lr_imgs, hr_imgs, out_path, batch_size, epochs):
        # 学習
        his = self.model.fit(lr_imgs, hr_imgs, batch_size=batch_size, epochs=epochs)

        print("___Training finished\n\n")

        # パラメータ保存
        print("___Saving parameter...")
        self.model.save_weights(out_path)
        print("___Successfully completed\n\n")

        return his, self.model

self.model.fit()に低解像度画像lr_imgsを学習データとして、高解像度画像hr_imgsを正解ラベルとして渡して
学習を開始します。
学習が終了次第、パラメータを保存します。


最後にPSNRメソッドの説明です。

PSNR(Peak signal-to-noise ratio)はピーク信号対雑音比と言われる、画像の劣化を表す評価指標です。
ピーク信号とか雑音とかなんのことだ?と思うかもしれませんが、すみません、私も専門外なので説明はできません。

この評価指標の単位は"db(デシベル)"です。
一般的に、PSNR 30db以上が綺麗に見えるらしいです。ただし、人間の感じ方とPSNR値は必ずしも一致するとは限らないので注意してください。

定義式は載せておきます。

$$PSNR = 10 \log_{10}\frac{MAX^2}{MSE}\qquad(1.1)$$

$MSE$は平均二乗誤差です。

$$MSE = \frac{1}{n} \sum_{i=1}^{n} (SR_i - HR_i)^2\qquad(2)$$

$MAX$は画素が取り得る最大値ですが、255で割って0~1に正規化しているため、$MAX$(最大値)は1となります。
式(1.1)の$MAX$に1を代入してやると、

$$PSNR = 10 \log_{10}\frac{1}{MSE}\qquad(1.2)$$

になり、商の対数の変換公式により、

$$PSNR = -10 \log_{10}MSE\qquad(1.3)$$

となります。
今回、平均二乗誤差$MSE$の計算で、tf.keras.backend.flatten()を用います。
なにかしらのtf.keras.backendの関数を使った場合、numpyの関数を使うことはできません。エラーが出ます。
なので式中のlogはtf.keras.backend.log()を使うのですが、これは常用対数ではなく自然対数です。
そのため、式(1.3)を

$$PSNR = -10 \frac{\ln MSE}{\ln 10}\qquad(1.4)$$

底の変換公式を使って式(1.4)のように式変形をする必要があります。

私は、tf.keras.backend.log()が常用対数ではないことと、恥ずかしながら数学が得意でないので、文献をみてもどうしてこのような式変形になるのか分かりませんでした。そのため備忘録として、このように細かく式変形の様子を記述しています。

式(1.4)をコードで表したのが下になります。

trainerクラス(PSNRメソッド)
    # PSNR(ピーク信号対雑音比)
    def psnr(self, h3, hr_imgs):
        return -10 * K.log(K.mean(K.flatten((h3 - hr_imgs))**2)) / np.log(10)

K.mean(K.flatten((h3 - hr_imgs))**2が$MSE$(平均二乗誤差)に当たります。
分母はnumpyの対数ですがこちらも自然対数です。なぜか、分母のlogもtf.keras.backend.log()にしようとすると、エラーになります。なぜでしょう?

まあ、こんな具合でPSNRの式が定義できました。
ちなみに、式(2)の$SR_i$と$HR_i$が同じ画像であった場合は、PSNR $+∞$dbとなります。


おわりに

超解像技術の中で簡単なSRCNNの実装法を説明してみましたが、いかがでしたでしょうか。
次回の「超解像技術-SRCNN-実装してみた(Tensorflow 2.0) 推論フェーズ編」では実際に超解像化してみます。

12
8
5

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
8