LoginSignup
7
5

More than 3 years have passed since last update.

超解像技術-SRGAN-実装してみた(Tensorflow 2.0) 推論フェーズ編

Last updated at Posted at 2020-12-05

こちらはITRC Advent Calendar 2020の6日目の記事です。
前の記事: @bluekey0725 さんのNetworkManagerで802.1xプロトコルを使用したwifiのホットスポットに機器を繋げたい
次の記事: @koseiinfratopさんの某アーティストグループは何を僕らに伝えたかったのか??? (WordCloudで歌詞化してみた)

概要

前回の続きで、SRGANの推論フェーズ編になります。

環境

-Software-
Windows 10 Home
Anaconda3 64-bit(Python3.7)
VSCode
-Library-
Tensorflow 2.2.0
opencv-python 4.1.2.30
-Hardware-
CPU: Intel core i9 9900K
GPU: NVIDIA GeForce RTX2080ti
RAM: 16GB 3200MHz

プログラム

Githubに上げておきます。
https://github.com/himazin331/Super-resolution-GAN
リポジトリには訓練フェーズ、推論フェーズ、データセット(General-100)が含まれています。

ソースコード

ソースコードだけ載せておきます。

srgan_pre.py
import argparse as arg
import os
import sys

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
import tensorflow.keras.layers as kl

import cv2
from PIL import Image
import matplotlib.pyplot as plt

import numpy as np

# Super-resolution Image Generator
class Generator(tf.keras.Model):
    def __init__(self, input_shape):
        super().__init__()

        input_shape_ps = (input_shape[0], input_shape[1], 64) 

        # Pre stage(Down Sampling)
        self.pre = [
            kl.Conv2D(64, kernel_size=9, strides=1,
                    padding="same", input_shape=input_shape),
            kl.Activation(tf.nn.relu)
        ]

        # Residual Block
        self.res = [
            [
                Res_block(64, input_shape) for _ in range(7)
            ]
        ]

        # Middle stage
        self.middle = [
            kl.Conv2D(64, kernel_size=3, strides=1, padding="same"),
            kl.BatchNormalization()
        ]

        # Pixel Shuffle(Up Sampling)
        self.ps =[
            [
                Pixel_shuffler(128, input_shape_ps) for _ in range(2)
            ],
            kl.Conv2D(3, kernel_size=9, strides=4, padding="same", activation="tanh")
        ]

    def call(self, x):

        # Pre stage
        pre = x
        for layer in self.pre:
            pre = layer(pre)

        # Residual Block
        res = pre
        for layer in self.res:
            for l in layer:
                res = l(res)

        # Middle stage
        middle = res
        for layer in self.middle:
            middle = layer(middle)
        middle += pre

        # Pixel Shuffle
        out = middle
        for layer in self.ps:
            if isinstance(layer, list):
                for l in layer:
                    out = l(out)
            else:
                out = layer(out)

        return out

# Pixel Shuffle
class Pixel_shuffler(tf.keras.Model):
    def __init__(self, out_ch, input_shape):
        super().__init__()

        self.conv = kl.Conv2D(out_ch, kernel_size=3, strides=1,
                            padding="same", input_shape=input_shape)
        self.act = kl.Activation(tf.nn.relu)

    # forward proc
    def call(self, x):

        d1 = self.conv(x)
        d2 = self.act(tf.nn.depth_to_space(d1, 2))

        return d2

# Residual Block
class Res_block(tf.keras.Model):
    def __init__(self, ch, input_shape):
        super().__init__()

        self.conv1 = kl.Conv2D(ch, kernel_size=3, strides=1,
                            padding="same", input_shape=input_shape)
        self.bn1 = kl.BatchNormalization()
        self.av1 = kl.Activation(tf.nn.relu)

        self.conv2 = kl.Conv2D(ch, kernel_size=3, strides=1,
                            padding="same")
        self.bn2 = kl.BatchNormalization()

        self.add = kl.Add()

    def call(self, x):

        d1 = self.av1(self.bn1(self.conv1(x)))
        d2 = self.bn2(self.conv2(d1))

        return self.add([x, d2])

def main():

    # Command line option
    parser = arg.ArgumentParser(description='Super-resolution GAN prediction')
    parser.add_argument('--param', '-p', type=str, default=None,
                        help='学習済みパラメータの指定(未指定ならエラー)')    
    parser.add_argument('--data_img', '-d', type=str, default=None,
                        help='画像ファイルの指定(未指定ならエラー)')
    parser.add_argument('--out', '-o', type=str, default=os.path.join(os.path.dirname(os.path.abspath(__file__)), "result"),
                        help='保存先指定(デフォルト値=./result)')                        
    parser.add_argument('--he', '-he', type=int, default=128,
                        help='リサイズの高さ指定(デフォルト値=128)')      
    parser.add_argument('--wi', '-wi', type=int, default=128,
                        help='リサイズの指定(デフォルト値=128)')
    parser.add_argument('--mag', '-m', type=int, default=2,
                        help='縮小倍率の指定(デフォルト値=2)')                           
    args = parser.parse_args()

    # Parameter-File not specified. -> Exception
    if args.param == None:
        print("\nException: Trained Parameter-File not specified.\n")
        sys.exit()
    # An Parameter-File that does not exist was specified. -> Exception
    if os.path.exists(args.param) != True:
        print("\nException: Trained Parameter-File {} is not found.\n".format(args.param))
        sys.exit()
    # Image not specified. -> Exception
    if args.data_img == None:
        print("\nException: Image not specified.\n")
        sys.exit()
    # An image that does not exist was specified. -> Exception
    if os.path.exists(args.data_img) != True:
        print("\nException: Image {} is not found.\n".format(args.data_img))
        sys.exit()
    # When 0 is entered for either width/height or Reduction ratio. -> Exception
    if args.he == 0 or args.wi == 0 or args.mag == 0:
        print("\nException: Invalid value has been entered.\n")
        sys.exit()

    # Setting info
    print("=== Setting information ===")
    print("# Trained Prameter-File: {}".format(os.path.abspath(args.param)))
    print("# Image: {}".format(args.data_img))
    print("# Output folder: {}".format(args.out))
    print("")
    print("# Height: {}".format(args.he))
    print("# Width: {}".format(args.wi))
    print("# Magnification: {}".format(args.mag))
    print("===========================")

    # Create output folder (If the folder exists, it will not be created.)
    os.makedirs(args.out, exist_ok=True)

    # Network Setup
    model = Generator(input_shape=(args.he, args.wi, 3))
    model.build((None, args.he, args.wi, 3))
    model.load_weights(args.param)

    # High-resolutin Image
    img = cv2.imread(args.data_img)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    hr_img = cv2.resize(img, (args.he, args.wi))

    # Low-resolution Image
    lr_img = cv2.resize(hr_img, (int(args.he/args.mag), int(args.wi/args.mag)))
    lr_img = cv2.resize(lr_img, (args.he, args.wi))
    lr_img_s = lr_img

    # Image processing
    lr_img = tf.convert_to_tensor(lr_img, np.float32) 
    lr_img = tf.convert_to_tensor(lr_img, np.float32) 
    lr_img = (lr_img - 127.5) / 127.5
    lr_img = lr_img[np.newaxis, :, :, :]

    # Super-resolution
    re = model.predict(lr_img)

    # Super-resolution Image processing
    re = np.reshape(re, (args.he, args.wi, 3))
    re = re * 127.5 + 127.5
    re = np.clip(re, 0.0, 255.0)

    # Low-resolution Image output
    lr_img = Image.fromarray(np.uint8(lr_img_s))
    lr_img.show()
    lr_img.save(os.path.join(args.out, "Low-resolution Image(SRGAN).bmp"))

    # Super-resolution Image output
    sr_img = Image.fromarray(np.uint8(re))
    sr_img.show()
    sr_img.save(os.path.join(args.out, "Super-resolution Image(SRGAN).bmp"))

    # High-resolution Image output
    hr_img = Image.fromarray(np.uint8(hr_img))
    hr_img.show()
    hr_img.save(os.path.join(args.out, "High-resolution Image(SRGAN).bmp"))

if __name__ == "__main__":
    main()

実行結果

HR - 高解像画像(オリジナル)
SR - 超解像画像
LR - 低解像画像

128x128 倍率 2 - Epoch 1000
128x128 倍率 2 - Epoch 1000

128x128 倍率 4 - Epoch 1000
128x128 倍率 4 - Epoch 1000

128x128 倍率 4 - Epoch 3000
128x128 倍率 4 - Epoch 3000

256x256 倍率 2 - Epoch 1000
256x256 倍率 2 - Epoch 1000

256x256 倍率 4 - Epoch 1000
256x256 倍率 4 - Epoch 1000

256x256 倍率 4 - Epoch 3000
256x256 倍率 4 - Epoch 3000

256x256 倍率 4 - Epoch 5000
256x256 倍率 4 - Epoch 5000

256x256ぐらいの解像度になるとメモリ不足に陥ったので、ミニバッチサイズを下げて学習しました。
Epoch数を増やせば綺麗なresultがでることが確認できました。

おわりに

SRGANはSRCNNよりも高いresultを出してくれます。が、前回でも触れたようにメモリリソースを多く食うものとなっているため注意が必要です。
SRCNN, SRGANとディープラーニングを利用した超解像技術に触れてみましたが、すっかり超解像技術の虜になってしまいました(笑)
今後もまた超解像技術に触れてみたいと考えています。
ここまでの閲覧ありがとうございました。

7
5
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
5