2
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

DeepLearningを用いた超解像手法/RED-Netの実装

Last updated at Posted at 2021-05-22

概要

深層学習を用いた、単一画像における超解像手法であるRED-Netの実装したので、それのまとめの記事です。
この手法は、2016年のCVPRに採択された論文で、ノイズ除去と超解像の両方で使用できる手法でもあります。

Python + Tensorflow(Keras)で実装を行いました。
今回は、2倍拡大の超解像にチャレンジしました。

今回紹介するコードはGithubにも載せています。

1. 超解像のおさらい

超解像について簡単に説明をします。

超解像とは解像度の低い画像に対して、解像度を向上させる技術のことです。
ここでいう解像度が低いとは、画素数が少なかったり、高周波成分(輪郭などの鮮鋭な部分を表す成分)がないような画像のことです。
以下の図で例を示します。(図は[論文]より引用)
image.png

(a)は原画像、(b)は画素数の少ない画像を見やすいように原画像と同じ大きさにした画像、(c)は高周波成分を含まない画像の例です。
(b)と(c)は、荒かったりぼやけていたりしていると思います。
このような状態を解像度が低い画像といいます。

そして、超解像はこのような解像度が低い画像に処理を行い、(a)のような精細な画像を出力することを目的としています。

2. Red-Netのアルゴリズム

まずはじめに、RED-Netのアルゴリズムの概要図を示します。(図は論文から引用)
image.png

同じ数のConvolution layerDeconvolution layerで構成されており、
Deconvolutionを2回行うごとに、Skip connectionConvolutionの結果を足し合わせています。
Skip connectionは、学習に伴う勾配消失の問題を防ぐために使用されています。

また、Skip connectionによる結果の足し合わせを行ったとき、以下の図の例のようにReLUの活性化関数を通します。
(図は論文から引用)
image.png

上の図では、実線がConvolution layerで破線はDeconvolution layerを表しています。

今回実装したRed-Netは、事前にbicubicで拡大処理を行います。

3. 実装したアルゴリズム

今回、実装したRed-Netのモデルは以下のように組みました。(コードの一部を抽出)

model.py
class RED_Net():
    def __init__(self, conv_num):
        #parameter set
        self.conv_num = conv_num
        self.deconv_num = conv_num    #Number of convolution layers == number of deconvolutions layers
        self.input_channels = 1       #input gray scale images
        self.filter_num = 64
        self.kernel_size = (3, 3)

        #create a list to store skip connections
        if self.conv_num % 2 == 1:
            self.skip_connection_list = (self.conv_num // 2 + 1) * [None]
        else:
            self.skip_connection_list = (self.conv_num // 2) * [None]

    def RED_Net_odd(self):
        """
        if conv_num % 2 == 1
        """
        input_shape = Input((None, None, self.input_channels))
        self.skip_connection_list[0] = input_shape
        
        #convolution part
        conv = input_shape
        for i in range(self.conv_num):
            conv = Conv2D(filters = self.filter_num, kernel_size = self.kernel_size, padding = "same")(conv)

            if i % 2 == 1:
                self.skip_connection_list[(i // 2 + 1)] = conv

        #deconvolution part
        for i in range(self.deconv_num - 1):
            conv = Conv2DTranspose(filters = self.filter_num, kernel_size = self.kernel_size, padding = "same")(conv)
            
            #add skip connections
            if i % 2 == 0:
                deconv_skip = Add()([conv, self.skip_connection_list[-1 * (i // 2 + 1)]])
                conv = ReLU()(deconv_skip)

        conv = Conv2DTranspose(filters = self.input_channels, kernel_size = self.kernel_size, padding = "same")(conv)
        #add skip connections
        deconv_skip = Add()([conv, self.skip_connection_list[0]])
        Output = ReLU()(deconv_skip)

        model = Model(inputs = input_shape, outputs = Output)
        model.summary()

        return model

    def RED_Net_even(self):
        """
        if conv_num % 2 == 0
        """
        input_shape = Input((None, None, self.input_channels))
        self.skip_connection_list[0] = input_shape
        
        #convolution part   
        conv = input_shape
        for i in range(self.conv_num):
            conv = Conv2D(filters = self.filter_num, kernel_size = self.kernel_size, padding = "same")(conv)

            if i % 2 == 1 and i != self.conv_num - 1:
                self.skip_connection_list[(i // 2 + 1)] = conv

        #deconvolution part
        for i in range(self.deconv_num - 1):
            conv = Conv2DTranspose(filters = self.filter_num, kernel_size = self.kernel_size, padding = "same")(conv)
            
            #add skip connections
            if i % 2 == 1:
                deconv_skip = Add()([conv, self.skip_connection_list[-1 * (i // 2 + 1)]])
                conv = ReLU()(deconv_skip)

        conv = Conv2DTranspose(filters = self.input_channels, kernel_size = self.kernel_size, padding = "same")(conv)
        #add skip connections
        deconv_skip = Add()([conv, self.skip_connection_list[0]])
        Output = ReLU()(deconv_skip)

        model = Model(inputs = input_shape, outputs = Output)
        model.summary()
        return model

今回は、Convolution layerDeconvolution layerの数に応じてモデルを2つ制作しました。
2つに分けた理由として、Skip connectionの数が変わってくるからです。
例えば、レイヤーの数がどちらも偶数の場合、2レイヤーずつConvolution layerの重み抽出とSkip connectionを行ってしまうと、この2つの処理が被ってしまいます。
簡単に以下の図を作ってみました。
rednet_explain.png

この図の矢印がSkip connectionを表しています。
青色の矢印は問題ないのですが、赤の矢印の処理がおかしくなってしまいます。
そこで、今回の実装ではこの赤い矢印の処理が出ないように、偶数・奇数でモデル分けをしました。

4. 使用したデータセット

今回は、データセットにDIV2K datasetを使用しました。
このデータセットは、単一画像のデータセットで、学習用が800種、検証用とテスト用が100種類ずつのデータセットです。
今回は、学習用データと検証用データを使用しました。
パスの構造はこんな感じです。

train_sharp - 0001.png 
            - 0002.png 
            - ...
            - 0800.png
            
val_sharp   - 0801.png
            - 0802.png 
            - ...
            - 0900.png

このデータをbicubicで縮小したりしてデータセットを生成しました。

5. 画像評価指標PSNR

今回は、画像評価指標としてPSNRを使用しました。
PSNR とは Peak Signal-to-Noise Ratio(ピーク信号対雑音比) の略で、単位はデジベル (dB) で表せます。
PSNR は信号の理論ピーク値と誤差の2乗平均を用いて評価しており、8bit画像の場合、255(最大濃淡値)を誤差の標準偏差で割った値です。
今回は、8bit画像を使用しましたが、計算量を減らすため、全画素値を255で割って使用しました。
そのため、最小濃淡値が0で最大濃淡値が1です。
dB値が高いほど拡大した画像が元画像に近いことを表します。
PSNRの式は以下のとおりです。

PSNR = 10\log_{10} \frac{1^2 * w * h}{\sum_{x=0}^{w-1}\sum_{y=0}^{h-1}(p_1(x,y) - p_2(x,y))^2 }

なお、$w$は画像の幅、$h$は画像の高さを表しており、$p_1$は元画像、$p_2$は PSNRを計測する画像を示しています。

6. コードの使用方法

このコード使用方法は、自分が執筆した別の実装記事とほとんど同じです。

① 学習データ生成

まず、Githubからコードを一式ダウンロードして、カレントディレクトリにします。
Windowsのコマンドでいうとこんな感じ。

C:~/keras_REDNet>

次に、main.pyから生成するデータセットのサイズ・大きさ・切り取る枚数、ファイルのパスなどを指定します。
main.pyの12~21行目です。
使うPCのメモリ数などに応じで、画像サイズや学習データ数の調整が必要です。

main.py
    parser.add_argument('--train_height', type = int, default = 36, help = "Train data size(height)")
    parser.add_argument('--train_width', type = int, default = 36, help = "Train data size(width)")
    parser.add_argument('--test_height', type = int, default = 720, help = "Test data size(height)")
    parser.add_argument('--test_width', type = int, default = 1280, help = "Test data size(width)")
    parser.add_argument('--train_dataset_num', type = int, default = 10000, help = "Number of train datasets to generate")
    parser.add_argument('--test_dataset_num', type = int, default = 5, help = "Number of test datasets to generate")
    parser.add_argument('--train_cut_num', type = int, default = 10, help = "Number of train data to be generated from a single image")
    parser.add_argument('--test_cut_num', type = int, default = 1, help = "Number of test data to be generated from a single image")
    parser.add_argument('--train_path', type = str, default = "../../dataset/reds_train_sharp", help = "The path containing the train image")
    parser.add_argument('--test_path', type = str, default = "../../dataset/reds_val_sharp", help = "The path containing the test image")

指定したら、コマンドでデータセットの生成をします。

C:~/keras_REDNet>python main.py --mode train_datacreate

これで、train_data_list.npzというファイルのデータセットが生成されます。

ついでにテストデータも同じようにコマンドで生成します。コマンドはこれです。

C:~/keras_REDNet>python main.py --mode test_datacreate

② 学習

次に学習を行います。
設定するパラメータの箇所は、epoch数と学習率、今回のモデルの層の数です。
まずは、main.pyの22~25行目。
conv_deconv_numは、Convolution layerDeconvolution layerの数です。
例えば、値を10に設定すると、どちらも10層ずつで合計20層となります。

main.py
    parser.add_argument('--conv_deconv_num', type = int, default = 10, help="Number of convolution and deconvolution nets")
    parser.add_argument('--learning_rate', type = float, default = 1e-4, help = "Learning_rate")
    parser.add_argument('--BATCH_SIZE', type = int, default = 32, help = "Training batch size")
    parser.add_argument('--EPOCHS', type = int, default = 1000, help = "Number of epochs to train for")

後は、学習のパラメータをあれこれ好きな値に設定します。
main.pyの74~90行目のパラメータを調節します。
レイヤーの数に応じて、ifでモデルの選択をしています。

main.py
        if args.conv_deconv_num % 2 == 0:
            train_model = model.RED_Net(args.conv_deconv_num).RED_Net_even()
        elif args.conv_deconv_num % 2 == 1:
            train_model = model.ReD_Net(args.conv_deconv_num).RED_Net_odd()

        optimizers = tf.keras.optimizers.Adam(lr = args.learning_rate)
        train_model.compile(loss = "mean_squared_error",
                        optimizer = optimizers,
                        metrics = [psnr])

        train_model.fit(train_x,
                        train_y,
                        epochs = args.EPOCHS,
                        verbose = 2,
                        batch_size = args.BATCH_SIZE)

        train_model.save("Red_net_model.h5")

optimizerはAdam、損失関数はmean_squared_errorを使用しています。

学習はデータ生成と同じようにコマンドで行います。

C:~/keras_REDNet>python main.py --mode train_model

これで、学習が終わるとモデルが出力されます。
他のモデルの学習を行う場合は、aを変えると同様にできます。

③ 評価

最後にモデルを使用してテストデータで評価を行います。
これも同様にコマンドで行いますが、事前に①でテストデータも生成しておいてください。

C:~/keras_REDNet>python main.py --mode evaluate

このコマンドで、画像を出力してくれます。

7. 結果

出力した画像の結果例を以下に示します。
なお、今回は輝度値のみで学習を行っているため、カラー画像には対応していません。

まず、今回検証に使用した画像は以下の通りです。

4_high.jpg

bicubicとREDNetの各アルゴリズムのPSNRの数値は以下のようになりました。

補間・拡大処理 PSNR(dB)
bicubic 38.52
REDNet 41.88

この画像では、PSNRの値が3ほど向上しました。
入力する画像にもよりそうですが、上の画像のように低周波成分と高周波成分がはっきりしている画像では高解像度化がしやすいのかもしれません。

最後に元画像・低解像度画像・生成画像の一部を並べて拡大表示してみます。
rednet.png

全体的に明るくなってしまってる気もしますが、エッジ部分はしっかりと復元されていそうです。

8. コードの全容

前述の通り、Githubに載せています。
pythonのファイルは主に3つあります。
各ファイルの役割は以下の通りです。

  • data_create.py : データ生成に関するコード。
  • model.py : 超解像のアルゴリズムに関するコード。
  • main.py : 主に使用するコード。

9. 実装環境

以前、バージョンが違って動かないというご質問を頂いたので、記載しておこうと思います。

  • PC環境

    • CPU : AMD Ryzen 5 3500 6-Core Processor
    • メモリ数 : 40GB
    • GPU : NVIDIA GeForce RTX 2060 SUPER
    • OS : Windows 10
  • ライブラリ環境

    • python : 3.7.9
    • tensorflow-gpu : 2.4.1
    • keras : 2.4.3
    • opencv-python : 4.4.0.43

10. まとめ

今回は、単一画像における超解像手法であるREDNetを実装してみました。
これでノイズ除去もできるってすごいですね...

記事が長くなってしまいましたが、最後まで読んでくださりありがとうございました。

参考文献

Image Restoration Using Very Deep Convolutional Encoder-Decoder Networks with Symmetric Skip Connections
 実装の参考にした論文。
画素数の壁を打ち破る 複数画像からの超解像技術
 超解像の説明のために使用。
DIV2K dataset
 今回使用したデータセット。

2
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?