概要
深層学習を用いた、動画像における超解像手法であるVESPCNの実装したので、それのまとめの記事です。
Python + Tensorflowで実装を行いました。
今回紹介するコードはGithubにも載せています。
1. 超解像のおさらい
超解像について簡単に説明をします。
超解像とは解像度の低い画像に対して、解像度を向上させる技術のことです。
ここでいう解像度が低いとは、画素数が少なかったり、高周波成分(輪郭などの鮮鋭な部分を表す)がないような画像のことです。
以下の図で例を示します。(図は[論文]より引用)
これは、超解像の説明をする時によく使われる画像です。
(a)は原画像、(b)は、画素数の少ない画像を見やすいように原画像と同じ大きさにした画像、(c)は、高周波成分を含まない画像の例です。
(b)と(c)は、荒かったりぼやけていたりしていると思います。
このような状態を解像度が低い画像といいます。
そして、超解像はこのような解像度が低い画像に処理を行い、(a)のような精細な画像を出力することを目的としています。
(今回紹介するVESPCNは、これを動画像に応用した超解像手法)
2. 論文の超解像アルゴリズム
超解像のアルゴリズムの概要図は以下の通りです。(図は論文から引用)
この超解像手法は、3枚の連続するフレームを入力して、1枚の高解像度画像を出力します。
前半のMotion estimationが動き補償などのデータ前処理の部分、
後半のSpatio-temporal ESPCNが超解像を行う部分になっています。
それぞれの流れについてもう少し詳しく説明していきます。
① Motion estimation
超解像アルゴリズムに入力する前のデータ前処理を行う箇所になります。
Motion estimation部分のアルゴリズムは以下の図の通りです。(図は論文から引用)
このデータ前処理の部分では、2枚の連続したフレームを入力します。
ここでは、Coarse flow estimationとFine flow estimationという2つのニューラルネットワークを通してデータ前処理を行います。
$\Delta c$と$\Delta f$は移動量を表すパラメータみたいなもので、オプティカルフローのパラメータを求めるのと似ています。
今回は、このデータ前処理でもニューラルネットワークを使用しています。
過去の超解像手法では、オプティカルフローを事前に求めたり、Bicubic法などで補間処理を行っていたため、データ前処理にかなり時間がかかってました。
そのため、データ前処理にもニューラルネットワークを使用することで、計算速度を上げる狙いもあります。
ニューラルネットワークのパラメータは、以下の通りです。(図は論文から引用)
kはフィルターのサイズ、nはフィルターの数、sはストライドの大きさを表しています。
それぞれ6つの層から構成されており、最後にupscaleで拡大しています。
この拡大の方法は、ESPCNで提案されたPixel shuffleを使用します。(図は論文から引用)
特徴としては、拡大したい倍率の2乗の値を入力する点です。
論文の例だと、5層でチャンネル数を32に拡張しています。
4倍拡大なので、4の2乗の16と2チャンネルというように分けることができ、4倍拡大された2チャンネルの結果を返します。
② Spatio-temporal ESPCN
Spatio-temporal ESPCNは、ESPCNを用いて、実際に超解像を行います。
ESPCNは、単一画像にのみ対応しているものでしたが、今回の論文では動画像に対応できるように入力チャンネル数を変更したりしています。
ESPCNの構造は前でも載せましたが、このようになっています。(図は論文から引用)
ESPCNについては、以前に実装をしているので興味があればそちらもご覧ください。(実装記事)
ESPCNで処理を行えば、結果が出力され本モデルの全ての工程が終わります。
3. 実装したアルゴリズム
今回は、VESPCNの全体を1つのニューラルネットワークのモデルとして実装してみました。
下の図が、本モデルの概要に近いと思います。(図は論文から引用)
入力する低解像度画像は3枚で、まずMotion estimationでデータ前処理を行います。
その後、ESPCNに3枚の画像を入力するという流れです。
コマンドラインでモデルを出力した例は、長すぎるので最後に参考として載せています。
4. 使用したデータセット
今回は、データセットにREDSを使用しました。
このデータセットは、動画像の超解像用のデータセットで、240種類の学習用データ、30種類の検証用データ、30種類のテスト用データの計300種類のデータセットです。
別の実装でもほとんどこのデータセットを使用しています。
パスの構造はこんな感じです。
train_sharp - 001 - フレーム100枚
- 002 - フレーム100枚
- ...
val_sharp - 001 - フレーム100枚
- 002 - フレーム100枚
- ...
このデータをBicubicで縮小したりしてデータセットを生成しました。
5. 画像評価指標PSNR
今回は、画像評価指標としてPSNRを使用しました。
PSNR とは Peak Signal-to-Noise Ratio(ピーク信号対雑音比) の略で、単位はデジベル (dB) で表せます。
PSNR は信号の理論ピーク値と誤差の2乗平均を用いて評価しており、8bit画像の場合、255(最大濃淡値)を誤差の標準偏差で割った値です。
今回は、8bit画像を使用しましたが、計算量を減らすため、全画素値を255で割って使用しました。
そのため、最小濃淡値が0で最大濃淡値が1です。
dB値が高いほど拡大した画像が元画像に近いことを表します。
PSNRの式は以下のとおりです。
PSNR = 10\log_{10} \frac{1^2 * w * h}{\sum_{x=0}^{w-1}\sum_{y=0}^{h-1}(p_1(x,y) - p_2(x,y))^2 }
なお、$w$は画像の幅、$h$は画像の高さを表しており、$p_1$は元画像、$p_2$はPSNRを計測する画像を示しています。
6. コードの使用方法
このコード使用方法は、自分が執筆した別の実装記事とほとんど同じです。
① 学習データ生成
まず、Githubからコードを一式ダウンロードして、カレントディレクトリにします。
Windowsのコマンドでいうとこんな感じ。
C:~/keras_RVSR>
次に、main.pyから生成するデータセットのサイズ・大きさ・切り取る枚数、ファイルのパスなどを指定します。
main.pyの15~26行目です。
使うPCのメモリ数などに応じで、画像サイズや学習データ数の調整が必要です。
train_height = 160 #HRのサイズ
train_width = 160
test_height = 720
test_width = 1280
train_dataset_num = 20000 #生成する学習用データの数
test_dataset_num = 10 #生成するテスト用データの数
train_cut_num = 10 #一組のフレームから生成するデータの数
test_cut_num = 1
train_movie_path = "../../reds/train_sharp" #動画像のフレームが入っているパス
test_movie_path = "../../reds/val_sharp"
指定したら、コマンドでデータセットの生成をします。
C:~/keras_RVSR>python main.py --mode train_datacreate
これで、train_data_list.npzというファイルのデータセットが生成されます。
ついでにテストデータも同じようにコマンドで生成します。コマンドはこれです。
C:~/keras_RVSR>python main.py --mode test_datacreate
② 学習
次に学習を行います。
設定するパラメータの箇所は、epoch数と学習率とかですかね...
まずは、main.pyの28~33行目
input_LR_num = 3 #入力するLRの数
input_channels = 1 #入力するLRのチャンネル数
mag = 4 #拡大倍率
MAX_BATSH_SIZE = 128 #最大のバッチサイズ
EPOCHS = 1000
後は、学習のパラメータをあれこれ好きな値に設定します。85~107行目です。
今回は、バッチサイズの開始が1で、10epochごとに2倍にしていく、という構造だったため、少しコードを変更しています。
具体的にはMAX_BATCH_SIZEを2で割り続けて、2を何乗すればいいか計算し、その回数に応じて10回で学習させています。
optimizers = tf.keras.optimizers.Adam(learning_rate=1e-4)
train_model.compile(loss = "mean_squared_error",
optimizer = optimizers,
metrics = [psnr])
x = MAX_BATSH_SIZE
i = 0
while x > 1:
x /= 2
i += 1
for n in range(i):
train_model.fit({"input_t_minus_1":train_x[0], "input_t":train_x[1], "input_t_plus_1":train_x[2]},
train_y,
epochs = 10,
verbose = 2,
batch_size = 2 ** n)
train_model.fit({"input_t_minus_1":train_x[0], "input_t":train_x[1], "input_t_plus_1":train_x[2]},
train_y,
epochs = EPOCHS - 10 * i,
verbose = 2,
batch_size = MAX_BATSH_SIZE)
optimizerはAdam、損失関数は最小二乗法を使用しています。
入力画像は今回は3枚で出力は1枚です。
学習はデータ生成と同じようにコマンドで行います。
C:~/keras_RVSR>python main.py --mode train_model
これで、学習が終わるとモデルが出力されます。
③ 評価
最後にモデルを使用してテストデータで評価を行います。
これも同様にコマンドで行いますが、事前に①でテストデータも生成しておいてください。
C:~/keras_RVSR>python main.py --mode evaluate
このコマンドで、画像を出力してくれます。
7. 結果
出力した画像はこのようになりました。
なお、今回は輝度値のみで学習を行っているため、カラー画像には対応していません。
対応させる場合は、modelのInputのchannel数を変えたり、データセット生成のchannel数を変える必要があります。
分かりにくいので、低解像度画像を拡大にして生成画像と同じサイズにしたものも載せておきます。
だいぶ粗さが取れているのが目視でもわかります。
最後に元画像・低解像度画像・生成画像の一部を並べて表示してみます。
並べてみると、高解像度化がされているのが分かります。
実際の動画像に処理をかける場合は、動画像をフレームに分解して、1枚ずつ処理を行う必要があります。
OpenCVで動画像をフレームごとに取得するとOKです。
8. コードの全容
前述の通り、Githubに載せています。
pythonのファイルは主に3つあります。
各ファイルの役割は以下の通りです。
- data_create.py : データ生成に関するコード。
- model.py : 超解像のアルゴリズムに関するコード。
- main.py : 主に使用するコード。
9. まとめ
今回は、最近読んだ論文のVESPCNを元に実装してみました。
モデルが複雑すぎていつもより実装に時間がかかりましたが、無事に実装できました。
次は単一画像の超解像手法の実装をしてみようと思います。
記事が長くなってしまいましたが、最後まで読んでくださりありがとうございました。
参考文献
・Real-Time Video Super-Resolution with Spatio-Temporal Networks and Motion
Compensation
今回実装の参考にした論文。
・画素数の壁を打ち破る 複数画像からの超解像技術
超解像の説明のために使用。
・REDSのデータセット
今回使用したデータセット。
[参考]本モデルのコマンドライン出力
参考として、上記で軽く触れた本モデルのコマンドラインを紹介します。
(各フレームのMotion estimationなども全て1つのモデルにまとめているので入り乱れています。)
結構長めのモデルになってしまいました。
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_t_minus_1 (InputLayer) [(None, None, None, 0
__________________________________________________________________________________________________
input_t (InputLayer) [(None, None, None, 0
__________________________________________________________________________________________________
input_t_plus_1 (InputLayer) [(None, None, None, 0
__________________________________________________________________________________________________
concatenate (Concatenate) (None, None, None, 2 0 input_t[0][0]
input_t_minus_1[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, None, None, 2 0 input_t[0][0]
input_t_plus_1[0][0]
__________________________________________________________________________________________________
conv2d (Conv2D) (None, None, None, 2 1224 concatenate[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D) (None, None, None, 2 1224 concatenate_2[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, None, None, 2 5208 conv2d[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D) (None, None, None, 2 5208 conv2d_10[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, None, None, 2 14424 conv2d_1[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D) (None, None, None, 2 14424 conv2d_11[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, None, None, 2 5208 conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D) (None, None, None, 2 5208 conv2d_12[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, None, None, 3 6944 conv2d_3[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D) (None, None, None, 3 6944 conv2d_13[0][0]
__________________________________________________________________________________________________
lambda (Lambda) (None, None, None, 2 0 conv2d_4[0][0]
__________________________________________________________________________________________________
lambda_2 (Lambda) (None, None, None, 2 0 conv2d_14[0][0]
__________________________________________________________________________________________________
multiply (Multiply) (None, None, None, 2 0 input_t_minus_1[0][0]
lambda[0][0]
__________________________________________________________________________________________________
multiply_2 (Multiply) (None, None, None, 2 0 input_t_plus_1[0][0]
lambda_2[0][0]
__________________________________________________________________________________________________
tf.__operators__.getitem (Slici (None, None, None) 0 multiply[0][0]
__________________________________________________________________________________________________
tf.__operators__.getitem_1 (Sli (None, None, None) 0 multiply[0][0]
__________________________________________________________________________________________________
tf.__operators__.getitem_4 (Sli (None, None, None) 0 multiply_2[0][0]
__________________________________________________________________________________________________
tf.__operators__.getitem_5 (Sli (None, None, None) 0 multiply_2[0][0]
__________________________________________________________________________________________________
tf.expand_dims (TFOpLambda) (None, None, None, 1 0 tf.__operators__.getitem[0][0]
__________________________________________________________________________________________________
tf.expand_dims_1 (TFOpLambda) (None, None, None, 1 0 tf.__operators__.getitem_1[0][0]
__________________________________________________________________________________________________
tf.expand_dims_4 (TFOpLambda) (None, None, None, 1 0 tf.__operators__.getitem_4[0][0]
__________________________________________________________________________________________________
tf.expand_dims_5 (TFOpLambda) (None, None, None, 1 0 tf.__operators__.getitem_5[0][0]
__________________________________________________________________________________________________
add (Add) (None, None, None, 1 0 tf.expand_dims[0][0]
tf.expand_dims_1[0][0]
__________________________________________________________________________________________________
add_5 (Add) (None, None, None, 1 0 tf.expand_dims_4[0][0]
tf.expand_dims_5[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, None, None, 1 0 input_t_minus_1[0][0]
add[0][0]
__________________________________________________________________________________________________
add_6 (Add) (None, None, None, 1 0 input_t_plus_1[0][0]
add_5[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, None, None, 5 0 input_t[0][0]
input_t_minus_1[0][0]
lambda[0][0]
add_1[0][0]
__________________________________________________________________________________________________
concatenate_3 (Concatenate) (None, None, None, 5 0 input_t[0][0]
input_t_plus_1[0][0]
lambda_2[0][0]
add_6[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, None, None, 2 3024 concatenate_1[0][0]
__________________________________________________________________________________________________
conv2d_15 (Conv2D) (None, None, None, 2 3024 concatenate_3[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D) (None, None, None, 2 5208 conv2d_5[0][0]
__________________________________________________________________________________________________
conv2d_16 (Conv2D) (None, None, None, 2 5208 conv2d_15[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D) (None, None, None, 2 5208 conv2d_6[0][0]
__________________________________________________________________________________________________
conv2d_17 (Conv2D) (None, None, None, 2 5208 conv2d_16[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D) (None, None, None, 2 5208 conv2d_7[0][0]
__________________________________________________________________________________________________
conv2d_18 (Conv2D) (None, None, None, 2 5208 conv2d_17[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D) (None, None, None, 8 1736 conv2d_8[0][0]
__________________________________________________________________________________________________
conv2d_19 (Conv2D) (None, None, None, 8 1736 conv2d_18[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda) (None, None, None, 2 0 conv2d_9[0][0]
__________________________________________________________________________________________________
lambda_3 (Lambda) (None, None, None, 2 0 conv2d_19[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, None, None, 2 0 lambda[0][0]
lambda_1[0][0]
__________________________________________________________________________________________________
add_7 (Add) (None, None, None, 2 0 lambda_2[0][0]
lambda_3[0][0]
__________________________________________________________________________________________________
multiply_1 (Multiply) (None, None, None, 2 0 input_t_minus_1[0][0]
add_2[0][0]
__________________________________________________________________________________________________
multiply_3 (Multiply) (None, None, None, 2 0 input_t_plus_1[0][0]
add_7[0][0]
__________________________________________________________________________________________________
tf.__operators__.getitem_2 (Sli (None, None, None) 0 multiply_1[0][0]
__________________________________________________________________________________________________
tf.__operators__.getitem_3 (Sli (None, None, None) 0 multiply_1[0][0]
__________________________________________________________________________________________________
tf.__operators__.getitem_6 (Sli (None, None, None) 0 multiply_3[0][0]
__________________________________________________________________________________________________
tf.__operators__.getitem_7 (Sli (None, None, None) 0 multiply_3[0][0]
__________________________________________________________________________________________________
tf.expand_dims_2 (TFOpLambda) (None, None, None, 1 0 tf.__operators__.getitem_2[0][0]
__________________________________________________________________________________________________
tf.expand_dims_3 (TFOpLambda) (None, None, None, 1 0 tf.__operators__.getitem_3[0][0]
__________________________________________________________________________________________________
tf.expand_dims_6 (TFOpLambda) (None, None, None, 1 0 tf.__operators__.getitem_6[0][0]
__________________________________________________________________________________________________
tf.expand_dims_7 (TFOpLambda) (None, None, None, 1 0 tf.__operators__.getitem_7[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, None, None, 1 0 tf.expand_dims_2[0][0]
tf.expand_dims_3[0][0]
__________________________________________________________________________________________________
add_8 (Add) (None, None, None, 1 0 tf.expand_dims_6[0][0]
tf.expand_dims_7[0][0]
__________________________________________________________________________________________________
add_4 (Add) (None, None, None, 1 0 input_t_minus_1[0][0]
add_3[0][0]
__________________________________________________________________________________________________
add_9 (Add) (None, None, None, 1 0 input_t_plus_1[0][0]
add_8[0][0]
__________________________________________________________________________________________________
concatenate_4 (Concatenate) (None, None, None, 3 0 add_4[0][0]
input_t[0][0]
add_9[0][0]
__________________________________________________________________________________________________
conv2d_20 (Conv2D) (None, None, None, 9 684 concatenate_4[0][0]
__________________________________________________________________________________________________
conv2d_21 (Conv2D) (None, None, None, 3 2624 conv2d_20[0][0]
__________________________________________________________________________________________________
conv2d_22 (Conv2D) (None, None, None, 1 4624 conv2d_21[0][0]
__________________________________________________________________________________________________
lambda_4 (Lambda) (None, None, None, 1 0 conv2d_22[0][0]
==================================================================================================
Total params: 114,716
Trainable params: 114,716
Non-trainable params: 0
__________________________________________________________________________________________________