LoginSignup
4
2

More than 3 years have passed since last update.

超解像技術-SRGAN-実装してみた(Tensorflow 2.0) 訓練フェーズ編

Last updated at Posted at 2020-12-05

概要

超解像技術とは、画像の解像度を高める技術です。
前回ではSRCNNを実装してみましたが、今回はSRGAN(Super-Resolution Generative Adversarial Network)を実装しました。

今回はSRGANの訓練フェーズ編です。
次回はSRGANの推論フェーズ編になります。

環境

-Software-
Windows 10 Home
Anaconda3 64-bit(Python3.7)
VSCode
-Library-
Tensorflow 2.2.0
opencv-python 4.1.2.30
-Hardware-
CPU: Intel core i9 9900K
GPU: NVIDIA GeForce RTX2080ti
RAM: 16GB 3200MHz

参考

ディープラーニングによる画像の拡大技術
Twitter社が発表した超解像ネットワークをchainerで再実装
SRGANをpytorchで実装してみた
SRGAN実装1(Keras)
SRGAN実装2(Keras)
https://www.metamaru.com/entry/2019/12/06/170934(11/08現在非公開)
https://medium.com/@crosssceneofwindff/srgan%E3%82%92%E7%94%A8%E3%81%84%E3%81%9F%E8%B6%85%E8%A7%A3%E5%83%8F-cf7fac7877294146(削除)

プログラム

GitHubに上げておきます。
https://github.com/himazin331/Super-resolution-GAN
リポジトリには訓練フェーズ、推論フェーズが含まれています。

今回は、データセットにGeneral-100を使いました。デモとして使えるようにGitHubのリポジトリにデータセットも上げてあります。

ソースコード

srgan_tr.py
import argparse as arg
import sys
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf

import tensorflow.keras.layers as kl
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.python.keras import backend as K
import cv2
import numpy as np
import matplotlib.pyplot as plt

# Super-resolution Image Generator
class Generator(tf.keras.Model):
    def __init__(self, input_shape):
        super().__init__()

        input_shape_ps = (input_shape[0], input_shape[1], 64) 

        # Pre stage(Down Sampling)
        self.pre = [
            kl.Conv2D(64, kernel_size=9, strides=1,
                    padding="same", input_shape=input_shape),
            kl.Activation(tf.nn.relu)
        ]

        # Residual Block
        self.res = [
            [
                Res_block(64, input_shape) for _ in range(7)
            ]
        ]

        # Middle stage
        self.middle = [
            kl.Conv2D(64, kernel_size=3, strides=1, padding="same"),
            kl.BatchNormalization()
        ]

        # Pixel Shuffle(Up Sampling)
        self.ps =[
            [
                Pixel_shuffler(128, input_shape_ps) for _ in range(2)
            ],
            kl.Conv2D(3, kernel_size=9, strides=4, padding="same", activation="tanh")
        ]

    def call(self, x):

        # Pre stage
        pre = x
        for layer in self.pre:
            pre = layer(pre)

        # Residual Block
        res = pre
        for layer in self.res:
            for l in layer:
                res = l(res)

        # Middle stage
        middle = res
        for layer in self.middle:
            middle = layer(middle)
        middle += pre

        # Pixel Shuffle
        out = middle
        for layer in self.ps:
            if isinstance(layer, list):
                for l in layer:
                    out = l(out)
            else:
                out = layer(out)

        return out

# Discriminator 
class Discriminator(tf.keras.Model):
    def __init__(self, input_shape):
        super().__init__()

        self.conv1 = kl.Conv2D(64, kernel_size=3, strides=1,
                            padding="same", input_shape=input_shape)
        self.act1 = kl.Activation(tf.nn.relu)

        self.conv2 = kl.Conv2D(64, kernel_size=3, strides=2,
                            padding="same")
        self.bn1 = kl.BatchNormalization()
        self.act2 = kl.LeakyReLU()

        self.conv3 = kl.Conv2D(128, kernel_size=3, strides=1,
                            padding="same")
        self.bn2 = kl.BatchNormalization()
        self.act3 = kl.LeakyReLU()

        self.conv4 = kl.Conv2D(128, kernel_size=3, strides=2,
                            padding="same")
        self.bn3 = kl.BatchNormalization()
        self.act4 = kl.LeakyReLU()

        self.conv5 = kl.Conv2D(256, kernel_size=3, strides=1,
                            padding="same")
        self.bn4 = kl.BatchNormalization()
        self.act5 = kl.LeakyReLU()

        self.conv6 = kl.Conv2D(256, kernel_size=3, strides=2,
                            padding="same")
        self.bn5 = kl.BatchNormalization()
        self.act6 = kl.LeakyReLU()

        self.conv7 = kl.Conv2D(512, kernel_size=3, strides=1,
                            padding="same")
        self.bn6 = kl.BatchNormalization()
        self.act7 = kl.LeakyReLU()

        self.conv8 = kl.Conv2D(512, kernel_size=3, strides=2,
                            padding="same")
        self.bn7 = kl.BatchNormalization()
        self.act8 = kl.LeakyReLU()

        self.flt = kl.Flatten()

        self.dens1 = kl.Dense(1024, activation=kl.LeakyReLU())
        self.dens2 = kl.Dense(1, activation="sigmoid")

    def call(self, x):

        d1 = self.act1(self.conv1(x))
        d2 = self.act2(self.bn1(self.conv2(d1)))
        d3 = self.act3(self.bn2(self.conv3(d2)))
        d4 = self.act4(self.bn3(self.conv4(d3)))
        d5 = self.act5(self.bn4(self.conv5(d4)))
        d6 = self.act6(self.bn5(self.conv6(d5)))
        d7 = self.act7(self.bn6(self.conv7(d6)))
        d8 = self.act8(self.bn7(self.conv8(d7)))

        d9 = self.dens1(self.flt(d8))
        d10 = self.dens2(d9)

        return d10

# Pixel Shuffle
class Pixel_shuffler(tf.keras.Model):
    def __init__(self, out_ch, input_shape):
        super().__init__()

        self.conv = kl.Conv2D(out_ch, kernel_size=3, strides=1,
                            padding="same", input_shape=input_shape)
        self.act = kl.Activation(tf.nn.relu)

    # forward proc
    def call(self, x):

        d1 = self.conv(x)
        d2 = self.act(tf.nn.depth_to_space(d1, 2))

        return d2

# Residual Block
class Res_block(tf.keras.Model):
    def __init__(self, ch, input_shape):
        super().__init__()

        self.conv1 = kl.Conv2D(ch, kernel_size=3, strides=1,
                            padding="same", input_shape=input_shape)
        self.bn1 = kl.BatchNormalization()
        self.av1 = kl.Activation(tf.nn.relu)

        self.conv2 = kl.Conv2D(ch, kernel_size=3, strides=1,
                            padding="same")
        self.bn2 = kl.BatchNormalization()

        self.add = kl.Add()

    def call(self, x):

        d1 = self.av1(self.bn1(self.conv1(x)))
        d2 = self.bn2(self.conv2(d1))

        return self.add([x, d2])

# Train
class trainer():
    def __init__(self, lr_img, hr_img):

        lr_shape = lr_img.shape # Low-resolution Image shape
        hr_shape = hr_img.shape # High-resolution Image shape

        # Content Loss Model setup
        input_tensor = tf.keras.Input(shape=hr_shape)
        self.vgg = VGG16(include_top=False, input_tensor=input_tensor)
        self.vgg.trainable = False
        self.vgg.outputs = [self.vgg.layers[9].output]  # VGG16 block3_conv3 output  

        # Content Loss Model
        self.cl_model = tf.keras.Model(input_tensor, self.vgg.outputs)

        # Discriminator
        discriminator_ = Discriminator(hr_shape)
        inputs = tf.keras.Input(shape=hr_shape)
        outputs = discriminator_(inputs)
        self.discriminator = tf.keras.Model(inputs=inputs, outputs=outputs)
        self.discriminator.compile(optimizer=tf.keras.optimizers.Adam(),
                                loss=tf.keras.losses.BinaryCrossentropy(),
                                metrics=['accuracy'])

        # Generator
        self.generator = Generator(lr_shape)

        # Combined Model setup
        lr_input = tf.keras.Input(shape=lr_shape)
        sr_output = self.generator(lr_input)

        self.discriminator.trainable = False # Discriminator train Disable
        d_fake = self.discriminator(sr_output)

        # SRGAN Model
        self.gan = tf.keras.Model(inputs=lr_input, outputs=[sr_output, d_fake])
        self.gan.compile(optimizer=tf.keras.optimizers.Adam(),
                        loss=[self.Content_loss, tf.keras.losses.BinaryCrossentropy()],
                        loss_weights=[1., 1e-3])

    # Content loss
    def Content_loss(self, hr_img, sr_img):
        return K.mean(K.abs(K.square(self.cl_model(hr_img) - self.cl_model(sr_img))))

    # PSNR
    def psnr(self, hr_img, sr_img):
        return cv2.PSNR(hr_img, sr_img)

    def train(self, lr_imgs, hr_imgs, out_path, batch_size, epoch):

        g_loss_plt = []
        d_loss_plt = []
        path = os.path.join(out_path, "graph.jpg")
        plt.figure(figsize=(12.8, 8.0), dpi=100)

        h_batch = int(batch_size / 2)

        real_lab = np.ones((h_batch, 1))  # High-resolution image label
        fake_lab = np.zeros((h_batch, 1)) # Super-resolution image label(Discriminator side)
        gan_lab = np.ones((h_batch, 1))

        # train run
        for epoch in range(epoch):

            # - Train Discriminator -

            # High-resolution image random pickups
            idx = np.random.randint(0, hr_imgs.shape[0], h_batch)
            hr_img = hr_imgs[idx]

            # Low-resolution image random pickups
            lr_img = lr_imgs[idx]

            # Discriminator enabled train
            self.discriminator.trainable = True

            # train by High-resolution image
            d_real_loss = self.discriminator.train_on_batch(hr_img, real_lab)

            # train by Super-resolution image
            sr_img = self.generator.predict(lr_img) 
            d_fake_loss = self.discriminator.train_on_batch(sr_img, fake_lab)

            # Discriminator average loss 
            d_loss = 0.5 * np.add(d_real_loss, d_fake_loss)

            # - Train Generator -

            # High-resolution image random pickups
            idx = np.random.randint(0, hr_imgs.shape[0], h_batch)
            hr_img = hr_imgs[idx]

            # Low-resolution image random pickups
            lr_img = lr_imgs[idx]

            # train by Generator
            self.discriminator.trainable = False
            g_loss = self.gan.train_on_batch(lr_img, [hr_img, gan_lab])

            # Epoch num, Discriminator/Generator loss, PSNR
            print("Epoch: {0} D_loss: {1:.3f} G_loss: {2:.3f} PSNR: {3:.3f}".format(epoch+1, d_loss[0], g_loss[0], self.psnr(hr_img, sr_img)))

            d_loss_plt.append(d_loss[0])
            g_loss_plt.append(g_loss[0])

            # Plotting the loss value
            if (epoch+1) % 50 == 0:
                plt.plot(d_loss_plt)
                plt.plot(g_loss_plt)
                plt.savefig(path, bbox_inches='tight', pad_inches=0.1)

        print("___Training finished\n\n")

        # Parameter-File and Graph Saving
        print("___Saving parameter...")
        self.generator.save_weights(os.path.join(out_path, "srgan.h5"))

        plt.plot(d_loss_plt, label="D_loss")
        plt.plot(g_loss_plt, label="G_loss")
        plt.savefig(path, bbox_inches='tight', pad_inches=0.1)
        print("___Successfully completed\n\n")

# Dataset creation
def create_dataset(data_dir, h, w, mag):

    print("\n___Creating a dataset...")

    prc = ['/', '-', '\\', '|']
    cnt = 0

    print("Number of image in a directory: {}".format(len(os.listdir(data_dir))))

    lr_imgs = []
    hr_imgs = []
    for c in os.listdir(data_dir):
        d = os.path.join(data_dir, c)

        _, ext = os.path.splitext(c)
        if ext.lower() not in ['.jpg', '.png', '.bmp']:
            continue

        img = cv2.imread(d)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (h, w)) # High-resolution image

        img_low = cv2.resize(img, (int(h/mag), int(w/mag))) # Image reduction
        img_low = cv2.resize(img_low, (h, w)) # Resize to original size

        lr_imgs.append(img_low)
        hr_imgs.append(img)

        cnt += 1

        print("\rLoading a LR-images and HR-images...{}    ({} / {})".format(prc[cnt%4], cnt, len(os.listdir(data_dir))), end='')

    print("\rLoading a LR-images and HR-images...Done    ({} / {})".format(cnt, len(os.listdir(data_dir))), end='')

    # Low-resolution image
    lr_imgs = tf.convert_to_tensor(lr_imgs, np.float32) 
    lr_imgs = (lr_imgs.numpy() - 127.5) / 127.5

    # High-resolution image
    hr_imgs = tf.convert_to_tensor(hr_imgs, np.float32)
    hr_imgs = (hr_imgs.numpy() - 127.5) / 127.5

    print("\n___Successfully completed\n")

    return lr_imgs, hr_imgs

def main():

    # Command line option
    parser = arg.ArgumentParser(description='Super-resolution GAN training')
    parser.add_argument('--data_dir', '-d', type=str, default=None,
                        help='Specify the image folder path (If not specified, an error)')
    parser.add_argument('--out', '-o', type=str,
                        default=os.path.dirname(os.path.abspath(__file__)),
                        help='Specify where to save parameters (default: ./srgan.h5)')
    parser.add_argument('--batch_size', '-b', type=int, default=32,
                        help='Specify the mini-batch size (default: 32)')
    parser.add_argument('--epoch', '-e', type=int, default=1000,
                        help='Specify the number of times to train (default: 1000)')
    parser.add_argument('--he', '-he', type=int, default=128,
                        help='Resize height (default: 128)')      
    parser.add_argument('--wi', '-wi', type=int, default=128,
                        help='Resize width (default: 128)')
    parser.add_argument('--mag', '-m', type=int, default=2,
                        help='Magnification (default: 2)')                           
    args = parser.parse_args()

    # Image folder not specified. -> Exception
    if args.data_dir == None:
        print("\nException: Folder not specified.\n")
        sys.exit()
    # An image folder that does not exist was specified. -> Exception
    if os.path.exists(args.data_dir) != True:
        print("\nException: Folder \"{}\" is not found.\n".format(args.data_dir))
        sys.exit()
    # When 0 is entered for either width/height or Reduction ratio. -> Exception
    if args.he == 0 or args.wi == 0 or args.mag == 0:
        print("\nException: Invalid value has been entered.\n")
        sys.exit()

    # Create output folder (If the folder exists, it will not be created.)
    os.makedirs(args.out, exist_ok=True)

    # Setting info
    print("=== Setting information ===")
    print("# Images folder: {}".format(os.path.abspath(args.data_dir)))
    print("# Output folder: {}".format(args.out))
    print("# Minibatch-size: {}".format(args.batch_size))
    print("# Epoch: {}".format(args.epoch))
    print("")
    print("# Height: {}".format(args.he))
    print("# Width: {}".format(args.wi))
    print("# Magnification: {}".format(args.mag))
    print("===========================")

    # dataset creation
    lr_imgs, hr_imgs = create_dataset(args.data_dir, args.he, args.wi, args.mag)

    print("___Start training...")
    Trainer = trainer(lr_imgs[0], hr_imgs[0])
    Trainer.train(lr_imgs, hr_imgs, out_path=args.out, batch_size=args.batch_size, epoch=args.epoch)

if __name__ == '__main__':
    main()

注意

このSRGANはVRAMをかなり専有するので、

2020-08-15 20:28:53.386530: E tensorflow/stream_executor/cuda/cuda_driver.cc:825] failed to alloc 4294967296 bytes on host: CUDA_ERROR_OUT_OF_MEMORY: out of memory
2020-08-15 20:28:53.386709: E tensorflow/stream_executor/cuda/cuda_driver.cc:825] failed to alloc 3865470464 bytes on host: CUDA_ERROR_OUT_OF_MEMORY: out of memory
Traceback (most recent call last):
  ~中略~
tensorflow.python.framework.errors_impl.ResourceExhaustedError:  OOM when allocating tensor with shape[16,128,512,512] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
         [[node model_2/generator/pixel_shuffer_1/conv2d_25/Conv2D (defined at srgan_tr.py:158) ]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.
 [Op:__inference_train_function_10485]

Errors may have originated from an input operation.
Input Source operations connected to node model_2/generator/pixel_shuffer_1/conv2d_25/Conv2D:
 model_2/generator/pixel_shuffer/activation_9/Relu (defined at srgan_tr.py:159)

Function call stack:
train_function

環境や設定によっては、上のようなResourceExhaustedErrorがでるかと思います。
私の環境(RTX2080ti)でも256x256の画像でミニバッチ数32を指定すると、このようなエラーが出て実行することができませんでした。

こういったエラーが出た場合は、ミニバッチ数を小さくするか、Google Colaboratoryで実行するなどの妥協策しかないようです。
(無理にメモリ解放をすると断片化するらしいので非推奨です。)

実行コマンド

python srgan_tr.py -d <フォルダ> -e <学習回数> -b <バッチサイズ> (-o <保存先> -he <高さ> -wi <幅> -m <縮小倍率(整数)>)

説明

コードの説明をしていきます。

アーキテクチャはSRGANをpytorchで実装してみたを参考にさせていただきました。

SRGANは名前の通り、GANが使われています。
低解像度画像を元にGeneratorで高解像度な画像を作り出し、Generatorが作り出した画像なのかオリジナル画像(訓練データ)なのかをDiscriminatorで鑑別します。
この処理を繰り返し、Generatorが作り出す画像(超解像画像)の質を高めていきます。

また、Generatorではより細かい特徴量を抽出するために多層に構築します。
そこで勾配消失を防ぐためにSkip connectionを用いてます。

Generator

低解像度画像を元に高解像度画像を生成します。

入力画像をダウンサンプリングし、残差ブロックによる特徴量の抽出を行います。
その後、残差ブロックの出力とダウンサンプリングの出力とでSkip connectionを結んだ後、Pixel Shufflerによるアップサンプリングを行います。

Generator
# Super-resolution Image Generator
class Generator(tf.keras.Model):
    def __init__(self, input_shape):
        super().__init__()

        input_shape_ps = (input_shape[0], input_shape[1], 64) 

        # Pre stage(Down Sampling)
        self.pre = [
            kl.Conv2D(64, kernel_size=9, strides=1,
                    padding="same", input_shape=input_shape),
            kl.Activation(tf.nn.relu)
        ]

        # Residual Block
        self.res = [
            [
                Res_block(64, input_shape) for _ in range(7)
            ]
        ]

        # Middle stage
        self.middle = [
            kl.Conv2D(64, kernel_size=3, strides=1, padding="same"),
            kl.BatchNormalization()
        ]

        # Pixel Shuffle(Up Sampling)
        self.ps =[
            [
                Pixel_shuffler(128, input_shape_ps) for _ in range(2)
            ],
            kl.Conv2D(3, kernel_size=9, strides=4, padding="same", activation="tanh")
        ]

    def call(self, x):

        # Pre stage
        pre = x
        for layer in self.pre:
            pre = layer(pre)

        # Residual Block
        res = pre
        for layer in self.res:
            for l in layer:
                res = l(res)

        # Middle stage
        middle = res
        for layer in self.middle:
            middle = layer(middle)
        middle += pre

        # Pixel Shuffle
        out = middle
        for layer in self.ps:
            if isinstance(layer, list):
                for l in layer:
                    out = l(out)
            else:
                out = layer(out)

        return out

アーキテクチャは以下の通りです。

Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_4 (InputLayer)            [(None, 128, 128, 3) 0
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 128, 128, 64) 15616       input_4[0][0]
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 128, 128, 64) 0           conv2d_8[0][0]
__________________________________________________________________________________________________
res_block (Res_block)           (None, 128, 128, 64) 74368       activation_1[0][0]
__________________________________________________________________________________________________
res_block_1 (Res_block)         (None, 128, 128, 64) 74368       res_block[0][0]
__________________________________________________________________________________________________
res_block_2 (Res_block)         (None, 128, 128, 64) 74368       res_block_1[0][0]
__________________________________________________________________________________________________
res_block_3 (Res_block)         (None, 128, 128, 64) 74368       res_block_2[0][0]
__________________________________________________________________________________________________
res_block_4 (Res_block)         (None, 128, 128, 64) 74368       res_block_3[0][0]
__________________________________________________________________________________________________
res_block_5 (Res_block)         (None, 128, 128, 64) 74368       res_block_4[0][0]
__________________________________________________________________________________________________
res_block_6 (Res_block)         (None, 128, 128, 64) 74368       res_block_5[0][0]
__________________________________________________________________________________________________
conv2d_23 (Conv2D)              (None, 128, 128, 64) 36928       res_block_6[0][0]
__________________________________________________________________________________________________
batch_normalization_21 (BatchNo (None, 128, 128, 64) 256         conv2d_23[0][0]
__________________________________________________________________________________________________
tf_op_layer_AddV2 (TensorFlowOp [(None, 128, 128, 64 0           batch_normalization_21[0][0]
                                                                 activation_1[0][0]
__________________________________________________________________________________________________
pixel_shuffler (Pixel_shuffler) (None, 256, 256, 32) 73856       tf_op_layer_AddV2[0][0]
__________________________________________________________________________________________________
pixel_shuffler_1 (Pixel_shuffle (None, 512, 512, 32) 36992       pixel_shuffler[0][0]
__________________________________________________________________________________________________
conv2d_26 (Conv2D)              (None, 128, 128, 3)  7779        pixel_shuffler_1[0][0]
==================================================================================================
Total params: 692,003
Trainable params: 690,083
Non-trainable params: 1,920
__________________________________________________________________________________________________

Residual Block

残差ブロックです。ResNetで使われてるやつです。
特に説明はしなくていいでしょう。

Res_block
# Residual Block
class Res_block(tf.keras.Model):
    def __init__(self, ch, input_shape):
        super().__init__()

        self.conv1 = kl.Conv2D(ch, kernel_size=3, strides=1,
                            padding="same", input_shape=input_shape)
        self.bn1 = kl.BatchNormalization()
        self.av1 = kl.Activation(tf.nn.relu)

        self.conv2 = kl.Conv2D(ch, kernel_size=3, strides=1,
                            padding="same")
        self.bn2 = kl.BatchNormalization()

        self.add = kl.Add()

    def call(self, x):

        d1 = self.av1(self.bn1(self.conv1(x)))
        d2 = self.bn2(self.conv2(d1))

        return self.add([x, d2])

Pixel Shuffler

Pixel Shufflerについてはこちらで詳しく解説されています。

従来、アップサンプリングでDeconvolutionという手法が使われていましたが、Deconvolutionは計算速度が遅く、Checkerboard Artifactとよばれる格子状の模様ができてしまう(詳細)という問題を抱えています。

そのため、近年ではDeconvolutionに代わってPixel Shufflerが用いられているそうです。

Pixel_shuffler
# Pixel Shuffle
class Pixel_shuffler(tf.keras.Model):
    def __init__(self, out_ch, input_shape):
        super().__init__()

        self.conv = kl.Conv2D(out_ch, kernel_size=3, strides=1,
                            padding="same", input_shape=input_shape)
        self.act = kl.Activation(tf.nn.relu)

    # forward proc
    def call(self, x):

        d1 = self.conv(x)
        d2 = self.act(tf.nn.depth_to_space(d1, 2))

        return d2

Discriminator

入力画像がGeneratorにより作成された超解像画像(偽物)と訓練データであるオリジナル画像(本物)のどちらであるか判別します。

Discriminator
# Discriminator 
class Discriminator(tf.keras.Model):
    def __init__(self, input_shape):
        super().__init__()

        self.conv1 = kl.Conv2D(64, kernel_size=3, strides=1,
                            padding="same", input_shape=input_shape)
        self.act1 = kl.Activation(tf.nn.relu)

        self.conv2 = kl.Conv2D(64, kernel_size=3, strides=2,
                            padding="same")
        self.bn1 = kl.BatchNormalization()
        self.act2 = kl.LeakyReLU()

        self.conv3 = kl.Conv2D(128, kernel_size=3, strides=1,
                            padding="same")
        self.bn2 = kl.BatchNormalization()
        self.act3 = kl.LeakyReLU()

        self.conv4 = kl.Conv2D(128, kernel_size=3, strides=2,
                            padding="same")
        self.bn3 = kl.BatchNormalization()
        self.act4 = kl.LeakyReLU()

        self.conv5 = kl.Conv2D(256, kernel_size=3, strides=1,
                            padding="same")
        self.bn4 = kl.BatchNormalization()
        self.act5 = kl.LeakyReLU()

        self.conv6 = kl.Conv2D(256, kernel_size=3, strides=2,
                            padding="same")
        self.bn5 = kl.BatchNormalization()
        self.act6 = kl.LeakyReLU()

        self.conv7 = kl.Conv2D(512, kernel_size=3, strides=1,
                            padding="same")
        self.bn6 = kl.BatchNormalization()
        self.act7 = kl.LeakyReLU()

        self.conv8 = kl.Conv2D(512, kernel_size=3, strides=2,
                            padding="same")
        self.bn7 = kl.BatchNormalization()
        self.act8 = kl.LeakyReLU()

        self.flt = kl.Flatten()

        self.dens1 = kl.Dense(1024, activation=kl.LeakyReLU())
        self.dens2 = kl.Dense(1, activation="sigmoid")

    def call(self, x):

        d1 = self.act1(self.conv1(x))
        d2 = self.act2(self.bn1(self.conv2(d1)))
        d3 = self.act3(self.bn2(self.conv3(d2)))
        d4 = self.act4(self.bn3(self.conv4(d3)))
        d5 = self.act5(self.bn4(self.conv5(d4)))
        d6 = self.act6(self.bn5(self.conv6(d5)))
        d7 = self.act7(self.bn6(self.conv7(d6)))
        d8 = self.act8(self.bn7(self.conv8(d7)))

        d9 = self.dens1(self.flt(d8))
        d10 = self.dens2(d9)

        return d10

アーキテクチャは以下の通り。

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_2 (InputLayer)         [(None, 128, 128, 3)]     0
_________________________________________________________________
conv2d (Conv2D)              (None, 128, 128, 64)      1792
_________________________________________________________________
activation (Activation)      (None, 128, 128, 64)      0
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 64, 64, 64)        36928
_________________________________________________________________
batch_normalization (BatchNo (None, 64, 64, 64)        256
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 64, 64, 64)        0
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 64, 64, 128)       73856
_________________________________________________________________
batch_normalization_1 (Batch (None, 64, 64, 128)       512
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 64, 64, 128)       0
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 32, 32, 128)       147584
_________________________________________________________________
batch_normalization_2 (Batch (None, 32, 32, 128)       512
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 32, 32, 128)       0
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 32, 32, 256)       295168
_________________________________________________________________
batch_normalization_3 (Batch (None, 32, 32, 256)       1024
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 32, 32, 256)       0
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 16, 16, 256)       590080
_________________________________________________________________
batch_normalization_4 (Batch (None, 16, 16, 256)       1024
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 16, 16, 256)       0
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 16, 16, 512)       1180160
_________________________________________________________________
batch_normalization_5 (Batch (None, 16, 16, 512)       2048
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 16, 16, 512)       0
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 8, 8, 512)         2359808
_________________________________________________________________
batch_normalization_6 (Batch (None, 8, 8, 512)         2048
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 8, 8, 512)         0
_________________________________________________________________
flatten (Flatten)            (None, 32768)             0
_________________________________________________________________
dense (Dense)                (None, 1024)              33555456
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 1025
=================================================================
Total params: 38,249,281
Trainable params: 38,245,569
Non-trainable params: 3,712
_________________________________________________________________

学習

trainerクラスではモデルの構築や学習を担います。

__ init __メソッド

モデルの構築など学習の前準備を行います。

trainerクラス(__init__メソッド)
    def __init__(self, lr_img, hr_img):

        lr_shape = lr_img.shape # Low-resolution Image shape
        hr_shape = hr_img.shape # High-resolution Image shape

        # Content Loss Model setup
        input_tensor = tf.keras.Input(shape=hr_shape)
        self.vgg = VGG16(include_top=False, input_tensor=input_tensor)
        self.vgg.trainable = False
        self.vgg.outputs = [self.vgg.layers[9].output]  # VGG16 block3_conv3 output  

        # Content Loss Model
        self.cl_model = tf.keras.Model(input_tensor, self.vgg.outputs)

        # Discriminator
        discriminator_ = Discriminator(hr_shape)
        inputs = tf.keras.Input(shape=hr_shape)
        outputs = discriminator_(inputs)
        self.discriminator = tf.keras.Model(inputs=inputs, outputs=outputs)
        self.discriminator.compile(optimizer=tf.keras.optimizers.Adam(),
                                loss=tf.keras.losses.BinaryCrossentropy(),
                                metrics=['accuracy'])

        # Generator
        self.generator = Generator(lr_shape)

        # Combined Model setup
        lr_input = tf.keras.Input(shape=lr_shape)
        sr_output = self.generator(lr_input)

        self.discriminator.trainable = False # Discriminator train Disable
        d_fake = self.discriminator(sr_output)

        # SRGAN Model
        self.gan = tf.keras.Model(inputs=lr_input, outputs=[sr_output, d_fake])
        self.gan.compile(optimizer=tf.keras.optimizers.Adam(),
                        loss=[self.Content_loss, tf.keras.losses.BinaryCrossentropy()],
                        loss_weights=[1., 1e-3])

このメソッドでやっていることは以下の3つです。
1. Content Loss算出で用いるモデルの構築
2. Discriminator構築
3. Combined Model(Generator+Discriminator)構築
順を追って説明していきます。

1. Content Loss算出で用いるモデルの構築

Content Loss(詳細は後述)の算出に使うネットワークモデルの構築を行います。

        # Content Loss Model setup
        input_tensor = tf.keras.Input(shape=hr_shape)
        self.vgg = VGG16(include_top=False, input_tensor=input_tensor)
        self.vgg.trainable = False
        self.vgg.outputs = [self.vgg.layers[9].output]  # VGG16 block3_conv3 output  

        # Content Loss Model
        self.cl_model = tf.keras.Model(input_tensor, self.vgg.outputs)

今回は学習済みVGG16モデルの10番目の出力をContent Lossにて使います

2. Discriminator構築

Discriminatorの構築を行います。

        # Discriminator
        discriminator_ = Discriminator(hr_shape)
        inputs = tf.keras.Input(shape=hr_shape)
        outputs = discriminator_(inputs)
        self.discriminator = tf.keras.Model(inputs=inputs, outputs=outputs)
        self.discriminator.compile(optimizer=tf.keras.optimizers.Adam(),
                                loss=tf.keras.losses.BinaryCrossentropy(),
                                metrics=['accuracy'])

Discriminatorクラスのインスタンスdiscriminator_に入力層を付与して、compileします。
loss_weightやlearning_rateはデフォルトのままにしてます。(めんどくさいので笑)

3. Combined Model(Generator+Discriminator)構築

GeneratorとDiscriminatorをあわせたCombined Modelを構築します。

        # Generator
        self.generator = Generator(lr_shape)

        # Combined Model setup
        lr_input = tf.keras.Input(shape=lr_shape)
        sr_output = self.generator(lr_input)

        self.discriminator.trainable = False
        d_fake = self.discriminator(sr_output)

        self.gan = tf.keras.Model(inputs=lr_input, outputs=[sr_output, d_fake])
        self.gan.compile(optimizer=tf.keras.optimizers.Adam(),
                        loss=[self.Content_loss, tf.keras.losses.BinaryCrossentropy()],
                        loss_weights=[1., 1e-3])

self.generatorに低解像度画像lr_inputを渡して超解像画像の出力サイズsr_outputを取得し、
それをself.discriminatorに渡してself.discriminatorの出力サイズd_fakeを得ます。
これらの出力サイズを用いて、Combined Modelのoutputsを定義します。

compileのlossについてですが、
outputsのsr_outputに対しては、超解像画像sr_outputとオリジナル画像(訓練データ)との誤差を求めるself.Content_loss(後述)を指定してやります。
d_fakeに対しては、Discriminatorの判別結果d_fakeと正解ラベルとの誤差を求めるBinaryCrossentropyを指定してやります。

Content_lossメソッド

Content Lossについてはこちらの記事がわかりやすいと思います。
軽く説明すると、ただ高解像度画像と超解像画像とで平均二乗誤差をとると出力結果がぼやけてしまうため、訓練済みネットワークの中間層の出力を使います。
高解像度画像と超解像画像を訓練済みネットワークに流し、それぞれの出力で平均二乗誤差をとります。
高解像度画像から抽出した特徴量と一致すれば、超解像画像は高解像度画像の特徴を持っていると言え、超解像画像は高解像度画像に近しくなっていると言えるという原理です。

trainerクラス(Content_lossメソッド)
    # Content loss
    def Content_loss(self, hr_img, sr_img):
        return K.mean(K.abs(K.square(self.cl_model(hr_img) - self.cl_model(sr_img))))

ここではVGG16の10番目の出力を用いて平均二乗誤差をとっています。

PSNRメソッド

画像における再現性の品質の尺度であるPSNR(Peak Signal-to-Noise Ratio, ピーク信号対雑音比)というものがあります。これは信号(画素)が取りうる最大のパワー(ピーク信号)と劣化をもたらすノイズ(雑音)の比率を表します。

単位はdB(デシベル)で、値が高いほど品質がよいとされてます。
しかし、必ずしもPSNR値と人間が感じる情報が一致するとは限らないです。

定義式

$$PSNR = 10 \log_{10}\frac{MAX^2}{MSE}$$

$MAX$は信号(画素)がとり得る最大のパワー(ピーク信号 = 最大値)。

$MSE$は平均二乗誤差。

$$MSE = \frac{1}{n} \sum_{i=1}^{n} (SR_i - HR_i)^2$$

$SR$は生成画像の画素ベクトル、$HR$は訓練データの画素ベクトルです。

平均二乗誤差の値がノイズの程度を表し、
MAXがピーク信号を表していて、それらを除算したものの対数が比率となります。

trainerクラス(PSNRメソッド)
    # PSNR
    def psnr(self, hr_img, sr_img):
        return cv2.PSNR(hr_img, sr_img)

OpenCVにあるPSNRメソッドを用いて算出しています。

trainメソッド

このメソッドで実際に学習を行っています。

trainerクラス(trainメソッド)
    def train(self, lr_imgs, hr_imgs, out_path, batch_size, epoch):

        g_loss_plt = []
        d_loss_plt = []
        path = os.path.join(out_path, "graph.jpg")
        plt.figure(figsize=(12.8, 8.0), dpi=100)

        h_batch = int(batch_size / 2)

        real_lab = np.ones((h_batch, 1))  # High-resolution image label
        fake_lab = np.zeros((h_batch, 1)) # Super-resolution image label(Discriminator side)
        gan_lab = np.ones((h_batch, 1))

        # train run
        for epoch in range(epoch):

            # - Train Discriminator -

            # High-resolution image random pickups
            idx = np.random.randint(0, hr_imgs.shape[0], h_batch)
            hr_img = hr_imgs[idx]

            # Low-resolution image random pickups
            lr_img = lr_imgs[idx]

            # Discriminator enabled train
            self.discriminator.trainable = True

            # train by High-resolution image
            d_real_loss = self.discriminator.train_on_batch(hr_img, real_lab)

            # train by Super-resolution image
            sr_img = self.generator.predict(lr_img) 
            d_fake_loss = self.discriminator.train_on_batch(sr_img, fake_lab)

            # Discriminator average loss 
            d_loss = 0.5 * np.add(d_real_loss, d_fake_loss)

            # - Train Generator -

            # High-resolution image random pickups
            idx = np.random.randint(0, hr_imgs.shape[0], h_batch)
            hr_img = hr_imgs[idx]

            # Low-resolution image random pickups
            lr_img = lr_imgs[idx]

            # train by Generator
            self.discriminator.trainable = False
            g_loss = self.gan.train_on_batch(lr_img, [hr_img, gan_lab])

            # Epoch num, Discriminator/Generator loss, PSNR
            print("Epoch: {0} D_loss: {1:.3f} G_loss: {2:.3f} PSNR: {3:.3f}".format(epoch+1, d_loss[0], g_loss[0], self.psnr(hr_img, sr_img)))

            d_loss_plt.append(d_loss[0])
            g_loss_plt.append(g_loss[0])

            # Plotting the loss value
            if (epoch+1) % 50 == 0:
                plt.plot(d_loss_plt)
                plt.plot(g_loss_plt)
                plt.savefig(path, bbox_inches='tight', pad_inches=0.1)

        print("___Training finished\n\n")

        # Parameter-File and Graph Saving
        print("___Saving parameter...")
        self.generator.save_weights(os.path.join(out_path, "srgan.h5"))

        plt.plot(d_loss_plt, label="D_loss")
        plt.plot(g_loss_plt, label="G_loss")
        plt.savefig(path, bbox_inches='tight', pad_inches=0.1)
        print("___Successfully completed\n\n")

細かい部分の説明は省いて、DiscriminatorとGeneratorの学習部分だけ説明します。

まずは、Discriminatorの学習です。
最初に高解像度画像、低解像度画像からそれぞれバッチサイズの半分の数だけ取り出します。

            # - Train Discriminator -

            # High-resolution image random pickups
            idx = np.random.randint(0, hr_imgs.shape[0], h_batch)
            hr_img = hr_imgs[idx]

            # Low-resolution image random pickups
            lr_img = lr_imgs[idx]

次にDiscriminatorに高解像度画像を学習させます。

            # train by High-resolution image
            d_real_loss = self.discriminator.train_on_batch(hr_img, real_lab)

hr_img高解像度画像real_labすべて 1 のラベルとなります。
これのバイナリ交差エントロピー誤差を取って学習していきます。

高解像度画像を学習させたら、超解像画像を学習させます。

            # train by Super-resolution image
            sr_img = self.generator.predict(lr_img) 
            d_fake_loss = self.discriminator.train_on_batch(sr_img, fake_lab)

Generatorに低解像度画像を流して、超解像画像sr_imgを得ます。
その超解像画像sr_imgとすべて 0 のラベルfake_labをDiscriminatorに渡して学習していきます。

これでDiscriminatorの学習は完了しました。
次にGeneratorの学習です。
厳密にはCombined Model(コード中ではgan)に対しての学習ですが、
Combined Model中のDiscriminatorの学習はしないように設定して、Generatorのみを学習します。

まず、Discriminator同様に高解像度画像、低解像度画像からそれぞれバッチサイズの半分の数だけ取り出します。

            # High-resolution image random pickups
            idx = np.random.randint(0, hr_imgs.shape[0], h_batch)
            hr_img = hr_imgs[idx]

            # Low-resolution image random pickups
            lr_img = lr_imgs[idx]

Generatorの学習ですが、先でも説明したとおり、Discriminatorの学習はしないということで、
self.discriminator.trainable = Falseと記述して学習の無効化をします。
無効化したら、低解像度画像lr_imgと正解データに高解像度画像hr_imgとすべて 1 のラベルgan_lab

            # train by Generator
            self.discriminator.trainable = False
            g_loss = self.gan.train_on_batch(lr_img, [hr_img, gan_lab])

Generatorが出力した超解像画像と高解像度画像とでContent Lossを取り、超解像画像とすべて 1 のラベルとでバイナリ交差エントロピー誤差を取ります。

これをEpoch数分繰り返して精度を高めていきます。


データセット作成

必要とするデータは高解像度画像のみで大丈夫です。
低解像度画像は高解像度画像から作成します。

create_dataset関数
# Dataset creation
def create_dataset(data_dir, h, w, mag):

    print("\n___Creating a dataset...")

    prc = ['/', '-', '\\', '|']
    cnt = 0

    print("Number of image in a directory: {}".format(len(os.listdir(data_dir))))

    lr_imgs = []
    hr_imgs = []
    for c in os.listdir(data_dir):
        d = os.path.join(data_dir, c)

        _, ext = os.path.splitext(c)
        if ext.lower() not in ['.jpg', '.png', '.bmp']:
            continue

        img = cv2.imread(d)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (h, w)) # High-resolution image

        img_low = cv2.resize(img, (int(h/mag), int(w/mag))) # Image reduction
        img_low = cv2.resize(img_low, (h, w)) # Resize to original size

        lr_imgs.append(img_low)
        hr_imgs.append(img)

        cnt += 1

        print("\rLoading a LR-images and HR-images...{}    ({} / {})".format(prc[cnt%4], cnt, len(os.listdir(data_dir))), end='')

    print("\rLoading a LR-images and HR-images...Done    ({} / {})".format(cnt, len(os.listdir(data_dir))), end='')

    # Low-resolution image
    lr_imgs = tf.convert_to_tensor(lr_imgs, np.float32) 
    lr_imgs = (lr_imgs.numpy() - 127.5) / 127.5

    # High-resolution image
    hr_imgs = tf.convert_to_tensor(hr_imgs, np.float32)
    hr_imgs = (hr_imgs.numpy() - 127.5) / 127.5

    print("\n___Successfully completed\n")

    return lr_imgs, hr_imgs

まず、画像を読み込みます。OpenCVで読み込んだ場合、画素の並びがBGRとなるため、
cv2.cvtColor(img, cv2.COLOR_BGR2RGB)でRGBに変換します。その後、指定したサイズにリサイズを行います。
これで高解像度画像の準備はひとまずOKです。

次に、低解像度画像を作成します。指定した縮小倍率で割った幅・高さに縮小します。
その後、縮小する前のサイズにリサイズし直せば作成完了です。

        img = cv2.imread(d)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (h, w)) # High-resolution image

        img_low = cv2.resize(img, (int(h/mag), int(w/mag))) # Image reduction
        img_low = cv2.resize(img_low, (h, w)) # Resize to original size

図を使って表すとこんな感じです。
image.png

蛇足ですが、OpenCVのcv2.resize()の補間アルゴリズムはデフォルトではBilinearが使われます。

おわりに

前回、SRCNNの記事を書いて、思いの外反応が得られました。
なかなか、忙しくてQiitaにまとめ上げられなかったのですが、やっとSRGANの記事を書き上げることができました。

次回はSRGAN 推論フェーズ編を投稿する予定ですので、よろしければ見ていただければなと思います。

4
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
2