LoginSignup
2
6

More than 1 year has passed since last update.

DeepLearningを用いた超解像手法/RVSRを参考にした実装

Last updated at Posted at 2021-04-29

概要

深層学習を用いた、動画像における超解像手法であるRVSRを参考に実装したので、それのまとめの記事です。
Python + Tensorflowで実装を行いました。

論文をまとめた記事もあるのでそちらも是非!
【論文メモ】超解像手法/RVSRの論文まとめ

今回紹介するコードはGithubにも載せています。

1. 超解像のおさらい

超解像について簡単に説明をします。

超解像とは解像度の低い画像に対して、解像度を向上させる技術のことです。
ここでいう解像度が低いとは、画素数が少なかったり、高周波成分(輪郭などの鮮鋭な部分を表す)がないような画像のことです。
以下の図で例を示します。(図は[論文]より引用)
image.png
これは、超解像の説明をする時によく使われる画像です。

(a)は原画像、(b)は、画素数の少ない画像を見やすいように原画像と同じ大きさにした画像、(c)は、高周波成分を含まない画像の例です。
(b)と(c)は、荒かったりぼやけていたりしていると思います。
このような状態を解像度が低い画像といいます。

そして、超解像はこのような解像度が低い画像に処理を行い、(a)のような精細な画像を出力することを目的としています。
(今回紹介するRVSRは、これを動画像に応用した超解像手法)

2. 論文の超解像アルゴリズム

今回実装したアルゴリズムを紹介する前に、論文のアルゴリズムの紹介をします。

超解像のアルゴリズムの概要図は以下の通りです。(図は論文から引用)
image.png

この超解像アルゴリズムは主に3つのパートに分かれています。

  • SR inference branch:高解像度画像の候補を出力。
  • Temporal modulation branch:weight mapの出力。
  • Temporal aggregation:最終的な結果の出力。

それぞれの流れについてもう少し詳しく説明していきます。

① SR inference branch

高解像度画像の候補を出力するパートです。

入力するフレーム数に応じて、複数のbranchを生成します。
入力フレーム数を $N$ 枚とすると、$2x - 1 = N$ となる $x$ の数だけbranchを生成します。
例えば、入力画像が5枚だと、branchの数は3つです。
ただし、入力フレームは必ず奇数になるようにします。

ここはニューラルネットワークを用いて高解像度化させています。
モデルはESPCNを使用しています。
入力フレームが複数なので、それに対応できるように少し調整をしていますが、ほとんど同じです。

ESPCNについては、以前実装記事 (超解像手法/ESPCNの実装) を書いていますのでよければそちらもご覧ください。

② Temporal modulation branch

weight mapを出力するパートです。
weight mapとは、各画素における重みを集約したものとなります。(いい説明思いつかないですね...)

ここの構造もニューラルネットワークを用います。
ESPCNと似たモデルを用いると書いていますが、詳細は書かれていなかったような気がします。

weight mapは、SR inference branchの数だけ出力します。
つまり、高解像度化された画像のそれぞれにこのweight mapを適応させるということです。

③ Temporal aggregation

SR inference branchで生成した高解像度画像の候補と、Temporal modulation branchで生成したweight mapをそれぞれ乗算します。
最後に乗算するので、Temporal modulation branchのweight mapの数はbranchの数と同じにしたということです。

最後に、乗算した結果を全て足し合わせたものを最終的な結果とします。

今回実装したアルゴリズムは、論文で紹介されている超解像アルゴリズムを一部変形させたものです。
深層学習のアルゴリズムが一部違いますが、大体同じとなっています。

3. 実装したアルゴリズム

今回実装したアルゴリズムは、高解像度画像の候補を出力するSR inference branchを中心としたものです。
下の図の赤枠の箇所です。
image.png

Temporal modulation branchのモデルは、ESPCNと似た構造とは書いていたのですが、詳細な構造が記載されていなかったので、そこは省きました。
Temporal modulation branchを実装してないと、Temporal aggregationもできません。
そこで、今回は生成した高解像度画像の候補の平均をとることで、その代わりとしました。
(論文では、平均を取って結果を出力した場合の結果もあったので、ある意味では論文と同じかもしれません。)

入力する低解像度画像は5枚としたので、SR branchの数は3つです。

コマンドラインでモデルを見ると以下の通りになります。
ここでは、見やすいようにSR branchごとに整列させて表示をしています。
(実際は1つのモデルにまとめているので入り乱れています。)

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_0 (InputLayer)            [(None, None, None,  0
_______________________________________________________________________________________________
input_1 (InputLayer)            [(None, None, None,  0
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, None, None,  0
_________________________________________________________________________________________________
input_3 (InputLayer)            [(None, None, None,  0
__________________________________________________________________________________________________
input_4 (InputLayer)            [(None, None, None,  0
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, None, None, 1 26          input_2[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, None, None, 3 320         conv2d[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, None, None, 1 4624        conv2d_1[0][0]
__________________________________________________________________________________________________
lambda (Lambda)                 (None, None, None, 1 0           conv2d_2[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, None, None, 3 0           input_1[0][0]
                                                                 input_2[0][0]
                                                                 input_3[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, None, None, 3 228         concatenate[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, None, None, 3 896         conv2d_3[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, None, None, 1 4624        conv2d_4[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, None, None, 1 0           conv2d_5[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, None, None, 5 0           input_0[0][0]
                                                                 input_1[0][0]
                                                                 input_2[0][0]
                                                                 input_3[0][0]
                                                                 input_4[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, None, None, 5 630         concatenate_1[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, None, None, 3 1472        conv2d_6[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, None, None, 1 4624        conv2d_7[0][0]
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, None, None, 1 0           conv2d_8[0][0]
__________________________________________________________________________________________________
average (Average)               (None, None, None, 1 0           lambda[0][0]
                                                                 lambda_1[0][0]
                                                                 lambda_2[0][0]
==================================================================================================
Total params: 17,444
Trainable params: 17,444
Non-trainable params: 0
__________________________________________________________________________________________________

入力フレームは5枚で、Input_Layerとして入力しています。
2つめ以降のbranchでは、入力フレームが複数になるので、Concatenateを使用して結合してから入力しています。
ニューラルネットワークはESPCNと同じ構造です。(実装記事)

4. 論文との相違点

前述の通り、超解像アルゴリズムの一部を変更しています。

また、学習データやテスト用データの生成方法を簡略化しています。
論文では、超解像処理を行う前に、ニューラルネットワークやオプティカルフローを計算したりしてデータ前処理を行っています。
しかし、今回は超解像アルゴリズムに焦点を当てて実装を行なったため、そちらは行いませんでした。
(データ前処理は、【論文メモ】超解像手法/RVSRの論文まとめで触れていますので、気になる方はそちらをご覧ください。本記事では説明を省きます。)

その代わり、Bicubic法で縮小したりしてデータセットを生成しています。

5. 使用したデータセット

今回は、データセットにREDSを使用しました。
このデータセットは、動画像の超解像用のデータセットで、240種類の学習用データ、30種類の検証用データ、30種類のテスト用データの計300種類のデータセットです。
別の実装でもほとんどこのデータセットを使用しています。

パスの構造はこんな感じです。

train_sharp - 001 - フレーム100枚
            - 002 - フレーム100枚
            - ...

val_sharp   - 001 - フレーム100枚
            - 002 - フレーム100枚
            - ...

このデータをBicubicで縮小したりしてデータセットを生成しました。

6. 画像評価指標PSNR

今回は、画像評価指標としてPSNRを使用しました。
PSNR とは Peak Signal-to-Noise Ratio(ピーク信号対雑音比) の略で、単位はデジベル (dB) で表せます。
PSNR は信号の理論ピーク値と誤差の2乗平均を用いて評価しており、8bit画像の場合、255(最大濃淡値)を誤差の標準偏差で割った値です。
今回は、8bit画像を使用しましたが、計算量を減らすため、全画素値を255で割って使用しました。
そのため、最小濃淡値が0で最大濃淡値が1です。
dB値が高いほど拡大した画像が元画像に近いことを表します。
PSNRの式は以下のとおりです。

PSNR = 10\log_{10} \frac{1^2 * w * h}{\sum_{x=0}^{w-1}\sum_{y=0}^{h-1}(p_1(x,y) - p_2(x,y))^2 }

なお、$w$は画像の幅、$h$は画像の高さを表しており、$p_1$は元画像、$p_2$はPSNRを計測する画像を示しています。

7. コードの使用方法

このコード使用方法は、自分が執筆した別の実装記事とほとんど同じです。

① 学習データ生成

まず、Githubからコードを一式ダウンロードして、カレントディレクトリにします。
Windowsのコマンドでいうとこんな感じ。

C:~/keras_RVSR>

次に、main.pyから生成するデータセットのサイズ・大きさ・切り取る枚数、ファイルのパスなどを指定します。
main.pyの15~26行目です。
使うPCのメモリ数などに応じで、画像サイズや学習データ数の調整が必要です。

main.py
    train_height = 120 #HRのサイズ
    train_width = 120
    test_height = 720  #HRのサイズ
    test_width = 1280

    train_dataset_num = 30000 #生成する学習データの数
    test_dataset_num = 10     #生成するテストデータの数
    train_cut_num = 10        #一組の動画から生成するデータの数
    test_cut_num = 1

    train_movie_path = "../../reds/train_sharp"  #動画のフレームが入っているパス
    test_movie_path = "../../reds/val_sharp"

指定したら、コマンドでデータセットの生成をします。

C:~/keras_RVSR>python main.py --mode train_datacreate

これで、train_data_list.npzというファイルのデータセットが生成されます。

ついでにテストデータも同じようにコマンドで生成します。コマンドはこれです。

C:~/keras_RVSR>python main.py --mode test_datacreate

② 学習

次に学習を行います。
設定するパラメータの箇所は、epoch数と学習率とかですかね...
まずは、main.pyの28~33行目

main.py
    input_LR_num = 5     #入力するLRフレーム数
    input_channels = 1   #入力するLRフレームのチャンネル数
    mag = 4              #拡大倍率

    BATSH_SIZE = 64
    EPOCHS = 3000

後は、学習のパラメータをあれこれ好きな値に設定します。85~94行目です。

main.py
        optimizers = tf.keras.optimizers.Adam(learning_rate=1e-4)
        train_model.compile(loss = "mean_squared_error",
                        optimizer = optimizers,
                        metrics = [psnr])

        train_model.fit({"input_0":train_x[0], "input_1":train_x[1], "input_2":train_x[2], "input_3":train_x[3], "input_4":train_x[4]},
                    train_y,
                    epochs = EPOCHS,
                    verbose = 2,
                    batch_size = BATSH_SIZE)

optimizerはAdam、損失関数は最小二乗法を使用しています。
入力画像は今回は5枚で出力は1枚です。

学習はデータ生成と同じようにコマンドで行います。

C:~/keras_RVSR>python main.py --mode train_model

これで、学習が終わるとモデルが出力されます。

③ 評価

最後にモデルを使用してテストデータで評価を行います。
これも同様にコマンドで行いますが、事前に①でテストデータも生成しておいてください。

C:~/keras_RVSR>python main.py --mode evaluate

このコマンドで、画像を出力してくれます。

8. 結果

出力した画像はこのようになりました。
なお、今回は輝度値のみで学習を行っているため、カラー画像には対応していません。
対応させる場合は、modelのInputのchannel数を変えたり、データセット生成のchannel数を変える必要があります。

元画像
5_high.jpg

低解像度画像(4倍縮小)
5_low.jpg

生成画像
PSNR:28.21
5_pred.jpg

分かりにくいので、低解像度画像を拡大にして生成画像と同じサイズにしたものも載せておきます。

低解像度画像(生成画像と同じサイズに拡大)
5_low_4.jpg

かなり、粗さが取れているのが分かります。
流石に4倍拡大だと、あっと驚くような精細な画像は出力されません。

最後に元画像・低解像度画像・生成画像の一部を並べて表示してみます。
rvsr.png

並べてみても分かる通り、高解像度化はちゃんとされていそうです。
従って、拡大を含めた超解像がしっかり行われていることが確認できます。

とはいえ、元画像と比べるとまだまだな部分があります。
4倍拡大だと、補う情報量が多くなるので、求めるような超解像はやはり難しいと思います。

もし、精細な画像を得るのであれば、拡大倍率を2倍とかでしてみるといいと思います。
また、4倍のままでより精細な画像を得る場合は、モデルのパラメータチューニングをしたり、学習データを増やしたり学習回数を増やしたりしてみるといいかもしれません。

また、実際の動画像に処理をかける場合は、動画像をフレームに分解して、1枚ずつ処理を行う必要があります。
OpenCVで動画像をフレームごとに取得して、って感じですかね。

9. コードの全容

前述の通り、Githubに載せています。
pythonのファイルは主に3つあります。
各ファイルの役割は以下の通りです。

  • data_create.py : データ生成に関するコード。
  • model.py : 超解像のアルゴリズムに関するコード。
  • main.py : 主に使用するコード。

10. まとめ

今回は、最近読んだ論文のRVSRを元に実装してみました。

だいぶモデルは複雑ですが、その分超解像がしっかりと行われていることが分かりました。

記事が長くなってしまいましたが、最後まで読んでくださりありがとうございました。
次は別のアルゴリズムの実装をしてみるつもりです。

参考文献

Robust Video Super-Resolution with Learned Temporal Dynamics
 今回実装の参考にした論文。
画素数の壁を打ち破る 複数画像からの超解像技術
 超解像の説明のために使用。
REDSのデータセット
 今回使用したデータセット。

2
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
6