###概要
超解像技術とは、画像の解像度を高める技術です。
前回ではSRCNNを実装してみましたが、今回はSRGAN(Super-Resolution Generative Adversarial Network)を実装しました。
今回はSRGANの訓練フェーズ編です。
次回はSRGANの推論フェーズ編になります。
###環境
-Software-
Windows 10 Home
Anaconda3 64-bit(Python3.7)
VSCode
-Library-
Tensorflow 2.2.0
opencv-python 4.1.2.30
-Hardware-
CPU: Intel core i9 9900K
GPU: NVIDIA GeForce RTX2080ti
RAM: 16GB 3200MHz
###参考
ディープラーニングによる画像の拡大技術
[Twitter社が発表した超解像ネットワークをchainerで再実装]
(http://hi-king.hatenablog.com/entry/2016/12/18/094146)
[SRGANをpytorchで実装してみた]
(https://qiita.com/pacifinapacific/items/ec338a500015ae8c33fe)
[SRGAN実装1(Keras)]
(https://github.com/deepak112/Keras-SRGAN/blob/master/train.py)
[SRGAN実装2(Keras)]
(https://github.com/eriklindernoren/Keras-GAN/blob/master/srgan/srgan.py)
~~https://www.metamaru.com/entry/2019/12/06/170934~~(11/08現在非公開)
~~https://medium.com/@crosssceneofwindff/srgan%E3%82%92%E7%94%A8%E3%81%84%E3%81%9F%E8%B6%85%E8%A7%A3%E5%83%8F-cf7fac7877294146~~(削除)
###プログラム
GitHubに上げておきます。
https://github.com/himazin331/Super-resolution-GAN
リポジトリには訓練フェーズ、推論フェーズが含まれています。
今回は、データセットにGeneral-100を使いました。デモとして使えるようにGitHubのリポジトリにデータセットも上げてあります。
###ソースコード
import argparse as arg
import sys
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import tensorflow.keras.layers as kl
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.python.keras import backend as K
import cv2
import numpy as np
import matplotlib.pyplot as plt
# Super-resolution Image Generator
class Generator(tf.keras.Model):
def __init__(self, input_shape):
super().__init__()
input_shape_ps = (input_shape[0], input_shape[1], 64)
# Pre stage(Down Sampling)
self.pre = [
kl.Conv2D(64, kernel_size=9, strides=1,
padding="same", input_shape=input_shape),
kl.Activation(tf.nn.relu)
]
# Residual Block
self.res = [
[
Res_block(64, input_shape) for _ in range(7)
]
]
# Middle stage
self.middle = [
kl.Conv2D(64, kernel_size=3, strides=1, padding="same"),
kl.BatchNormalization()
]
# Pixel Shuffle(Up Sampling)
self.ps =[
[
Pixel_shuffler(128, input_shape_ps) for _ in range(2)
],
kl.Conv2D(3, kernel_size=9, strides=4, padding="same", activation="tanh")
]
def call(self, x):
# Pre stage
pre = x
for layer in self.pre:
pre = layer(pre)
# Residual Block
res = pre
for layer in self.res:
for l in layer:
res = l(res)
# Middle stage
middle = res
for layer in self.middle:
middle = layer(middle)
middle += pre
# Pixel Shuffle
out = middle
for layer in self.ps:
if isinstance(layer, list):
for l in layer:
out = l(out)
else:
out = layer(out)
return out
# Discriminator
class Discriminator(tf.keras.Model):
def __init__(self, input_shape):
super().__init__()
self.conv1 = kl.Conv2D(64, kernel_size=3, strides=1,
padding="same", input_shape=input_shape)
self.act1 = kl.Activation(tf.nn.relu)
self.conv2 = kl.Conv2D(64, kernel_size=3, strides=2,
padding="same")
self.bn1 = kl.BatchNormalization()
self.act2 = kl.LeakyReLU()
self.conv3 = kl.Conv2D(128, kernel_size=3, strides=1,
padding="same")
self.bn2 = kl.BatchNormalization()
self.act3 = kl.LeakyReLU()
self.conv4 = kl.Conv2D(128, kernel_size=3, strides=2,
padding="same")
self.bn3 = kl.BatchNormalization()
self.act4 = kl.LeakyReLU()
self.conv5 = kl.Conv2D(256, kernel_size=3, strides=1,
padding="same")
self.bn4 = kl.BatchNormalization()
self.act5 = kl.LeakyReLU()
self.conv6 = kl.Conv2D(256, kernel_size=3, strides=2,
padding="same")
self.bn5 = kl.BatchNormalization()
self.act6 = kl.LeakyReLU()
self.conv7 = kl.Conv2D(512, kernel_size=3, strides=1,
padding="same")
self.bn6 = kl.BatchNormalization()
self.act7 = kl.LeakyReLU()
self.conv8 = kl.Conv2D(512, kernel_size=3, strides=2,
padding="same")
self.bn7 = kl.BatchNormalization()
self.act8 = kl.LeakyReLU()
self.flt = kl.Flatten()
self.dens1 = kl.Dense(1024, activation=kl.LeakyReLU())
self.dens2 = kl.Dense(1, activation="sigmoid")
def call(self, x):
d1 = self.act1(self.conv1(x))
d2 = self.act2(self.bn1(self.conv2(d1)))
d3 = self.act3(self.bn2(self.conv3(d2)))
d4 = self.act4(self.bn3(self.conv4(d3)))
d5 = self.act5(self.bn4(self.conv5(d4)))
d6 = self.act6(self.bn5(self.conv6(d5)))
d7 = self.act7(self.bn6(self.conv7(d6)))
d8 = self.act8(self.bn7(self.conv8(d7)))
d9 = self.dens1(self.flt(d8))
d10 = self.dens2(d9)
return d10
# Pixel Shuffle
class Pixel_shuffler(tf.keras.Model):
def __init__(self, out_ch, input_shape):
super().__init__()
self.conv = kl.Conv2D(out_ch, kernel_size=3, strides=1,
padding="same", input_shape=input_shape)
self.act = kl.Activation(tf.nn.relu)
# forward proc
def call(self, x):
d1 = self.conv(x)
d2 = self.act(tf.nn.depth_to_space(d1, 2))
return d2
# Residual Block
class Res_block(tf.keras.Model):
def __init__(self, ch, input_shape):
super().__init__()
self.conv1 = kl.Conv2D(ch, kernel_size=3, strides=1,
padding="same", input_shape=input_shape)
self.bn1 = kl.BatchNormalization()
self.av1 = kl.Activation(tf.nn.relu)
self.conv2 = kl.Conv2D(ch, kernel_size=3, strides=1,
padding="same")
self.bn2 = kl.BatchNormalization()
self.add = kl.Add()
def call(self, x):
d1 = self.av1(self.bn1(self.conv1(x)))
d2 = self.bn2(self.conv2(d1))
return self.add([x, d2])
# Train
class trainer():
def __init__(self, lr_img, hr_img):
lr_shape = lr_img.shape # Low-resolution Image shape
hr_shape = hr_img.shape # High-resolution Image shape
# Content Loss Model setup
input_tensor = tf.keras.Input(shape=hr_shape)
self.vgg = VGG16(include_top=False, input_tensor=input_tensor)
self.vgg.trainable = False
self.vgg.outputs = [self.vgg.layers[9].output] # VGG16 block3_conv3 output
# Content Loss Model
self.cl_model = tf.keras.Model(input_tensor, self.vgg.outputs)
# Discriminator
discriminator_ = Discriminator(hr_shape)
inputs = tf.keras.Input(shape=hr_shape)
outputs = discriminator_(inputs)
self.discriminator = tf.keras.Model(inputs=inputs, outputs=outputs)
self.discriminator.compile(optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.BinaryCrossentropy(),
metrics=['accuracy'])
# Generator
self.generator = Generator(lr_shape)
# Combined Model setup
lr_input = tf.keras.Input(shape=lr_shape)
sr_output = self.generator(lr_input)
self.discriminator.trainable = False # Discriminator train Disable
d_fake = self.discriminator(sr_output)
# SRGAN Model
self.gan = tf.keras.Model(inputs=lr_input, outputs=[sr_output, d_fake])
self.gan.compile(optimizer=tf.keras.optimizers.Adam(),
loss=[self.Content_loss, tf.keras.losses.BinaryCrossentropy()],
loss_weights=[1., 1e-3])
# Content loss
def Content_loss(self, hr_img, sr_img):
return K.mean(K.abs(K.square(self.cl_model(hr_img) - self.cl_model(sr_img))))
# PSNR
def psnr(self, hr_img, sr_img):
return cv2.PSNR(hr_img, sr_img)
def train(self, lr_imgs, hr_imgs, out_path, batch_size, epoch):
g_loss_plt = []
d_loss_plt = []
path = os.path.join(out_path, "graph.jpg")
plt.figure(figsize=(12.8, 8.0), dpi=100)
h_batch = int(batch_size / 2)
real_lab = np.ones((h_batch, 1)) # High-resolution image label
fake_lab = np.zeros((h_batch, 1)) # Super-resolution image label(Discriminator side)
gan_lab = np.ones((h_batch, 1))
# train run
for epoch in range(epoch):
# - Train Discriminator -
# High-resolution image random pickups
idx = np.random.randint(0, hr_imgs.shape[0], h_batch)
hr_img = hr_imgs[idx]
# Low-resolution image random pickups
lr_img = lr_imgs[idx]
# Discriminator enabled train
self.discriminator.trainable = True
# train by High-resolution image
d_real_loss = self.discriminator.train_on_batch(hr_img, real_lab)
# train by Super-resolution image
sr_img = self.generator.predict(lr_img)
d_fake_loss = self.discriminator.train_on_batch(sr_img, fake_lab)
# Discriminator average loss
d_loss = 0.5 * np.add(d_real_loss, d_fake_loss)
# - Train Generator -
# High-resolution image random pickups
idx = np.random.randint(0, hr_imgs.shape[0], h_batch)
hr_img = hr_imgs[idx]
# Low-resolution image random pickups
lr_img = lr_imgs[idx]
# train by Generator
self.discriminator.trainable = False
g_loss = self.gan.train_on_batch(lr_img, [hr_img, gan_lab])
# Epoch num, Discriminator/Generator loss, PSNR
print("Epoch: {0} D_loss: {1:.3f} G_loss: {2:.3f} PSNR: {3:.3f}".format(epoch+1, d_loss[0], g_loss[0], self.psnr(hr_img, sr_img)))
d_loss_plt.append(d_loss[0])
g_loss_plt.append(g_loss[0])
# Plotting the loss value
if (epoch+1) % 50 == 0:
plt.plot(d_loss_plt)
plt.plot(g_loss_plt)
plt.savefig(path, bbox_inches='tight', pad_inches=0.1)
print("___Training finished\n\n")
# Parameter-File and Graph Saving
print("___Saving parameter...")
self.generator.save_weights(os.path.join(out_path, "srgan.h5"))
plt.plot(d_loss_plt, label="D_loss")
plt.plot(g_loss_plt, label="G_loss")
plt.savefig(path, bbox_inches='tight', pad_inches=0.1)
print("___Successfully completed\n\n")
# Dataset creation
def create_dataset(data_dir, h, w, mag):
print("\n___Creating a dataset...")
prc = ['/', '-', '\\', '|']
cnt = 0
print("Number of image in a directory: {}".format(len(os.listdir(data_dir))))
lr_imgs = []
hr_imgs = []
for c in os.listdir(data_dir):
d = os.path.join(data_dir, c)
_, ext = os.path.splitext(c)
if ext.lower() not in ['.jpg', '.png', '.bmp']:
continue
img = cv2.imread(d)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (h, w)) # High-resolution image
img_low = cv2.resize(img, (int(h/mag), int(w/mag))) # Image reduction
img_low = cv2.resize(img_low, (h, w)) # Resize to original size
lr_imgs.append(img_low)
hr_imgs.append(img)
cnt += 1
print("\rLoading a LR-images and HR-images...{} ({} / {})".format(prc[cnt%4], cnt, len(os.listdir(data_dir))), end='')
print("\rLoading a LR-images and HR-images...Done ({} / {})".format(cnt, len(os.listdir(data_dir))), end='')
# Low-resolution image
lr_imgs = tf.convert_to_tensor(lr_imgs, np.float32)
lr_imgs = (lr_imgs.numpy() - 127.5) / 127.5
# High-resolution image
hr_imgs = tf.convert_to_tensor(hr_imgs, np.float32)
hr_imgs = (hr_imgs.numpy() - 127.5) / 127.5
print("\n___Successfully completed\n")
return lr_imgs, hr_imgs
def main():
# Command line option
parser = arg.ArgumentParser(description='Super-resolution GAN training')
parser.add_argument('--data_dir', '-d', type=str, default=None,
help='Specify the image folder path (If not specified, an error)')
parser.add_argument('--out', '-o', type=str,
default=os.path.dirname(os.path.abspath(__file__)),
help='Specify where to save parameters (default: ./srgan.h5)')
parser.add_argument('--batch_size', '-b', type=int, default=32,
help='Specify the mini-batch size (default: 32)')
parser.add_argument('--epoch', '-e', type=int, default=1000,
help='Specify the number of times to train (default: 1000)')
parser.add_argument('--he', '-he', type=int, default=128,
help='Resize height (default: 128)')
parser.add_argument('--wi', '-wi', type=int, default=128,
help='Resize width (default: 128)')
parser.add_argument('--mag', '-m', type=int, default=2,
help='Magnification (default: 2)')
args = parser.parse_args()
# Image folder not specified. -> Exception
if args.data_dir == None:
print("\nException: Folder not specified.\n")
sys.exit()
# An image folder that does not exist was specified. -> Exception
if os.path.exists(args.data_dir) != True:
print("\nException: Folder \"{}\" is not found.\n".format(args.data_dir))
sys.exit()
# When 0 is entered for either width/height or Reduction ratio. -> Exception
if args.he == 0 or args.wi == 0 or args.mag == 0:
print("\nException: Invalid value has been entered.\n")
sys.exit()
# Create output folder (If the folder exists, it will not be created.)
os.makedirs(args.out, exist_ok=True)
# Setting info
print("=== Setting information ===")
print("# Images folder: {}".format(os.path.abspath(args.data_dir)))
print("# Output folder: {}".format(args.out))
print("# Minibatch-size: {}".format(args.batch_size))
print("# Epoch: {}".format(args.epoch))
print("")
print("# Height: {}".format(args.he))
print("# Width: {}".format(args.wi))
print("# Magnification: {}".format(args.mag))
print("===========================")
# dataset creation
lr_imgs, hr_imgs = create_dataset(args.data_dir, args.he, args.wi, args.mag)
print("___Start training...")
Trainer = trainer(lr_imgs[0], hr_imgs[0])
Trainer.train(lr_imgs, hr_imgs, out_path=args.out, batch_size=args.batch_size, epoch=args.epoch)
if __name__ == '__main__':
main()
###注意
このSRGANはVRAMをかなり専有するので、
2020-08-15 20:28:53.386530: E tensorflow/stream_executor/cuda/cuda_driver.cc:825] failed to alloc 4294967296 bytes on host: CUDA_ERROR_OUT_OF_MEMORY: out of memory
2020-08-15 20:28:53.386709: E tensorflow/stream_executor/cuda/cuda_driver.cc:825] failed to alloc 3865470464 bytes on host: CUDA_ERROR_OUT_OF_MEMORY: out of memory
Traceback (most recent call last):
~中略~
tensorflow.python.framework.errors_impl.ResourceExhaustedError: OOM when allocating tensor with shape[16,128,512,512] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
[[node model_2/generator/pixel_shuffer_1/conv2d_25/Conv2D (defined at srgan_tr.py:158) ]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.
[Op:__inference_train_function_10485]
Errors may have originated from an input operation.
Input Source operations connected to node model_2/generator/pixel_shuffer_1/conv2d_25/Conv2D:
model_2/generator/pixel_shuffer/activation_9/Relu (defined at srgan_tr.py:159)
Function call stack:
train_function
環境や設定によっては、上のようなResourceExhaustedError
がでるかと思います。
私の環境(RTX2080ti)でも256x256の画像でミニバッチ数32を指定すると、このようなエラーが出て実行することができませんでした。
こういったエラーが出た場合は、ミニバッチ数を小さくするか、Google Colaboratoryで実行するなどの妥協策しかないようです。
(無理にメモリ解放をすると断片化するらしいので非推奨です。)
###実行コマンド
python srgan_tr.py -d <フォルダ> -e <学習回数> -b <バッチサイズ> (-o <保存先> -he <高さ> -wi <幅> -m <縮小倍率(整数)>)
###説明
コードの説明をしていきます。
アーキテクチャは[SRGANをpytorchで実装してみた]
(https://qiita.com/pacifinapacific/items/ec338a500015ae8c33fe)を参考にさせていただきました。
SRGANは名前の通り、GANが使われています。
低解像度画像を元にGeneratorで高解像度な画像を作り出し、Generatorが作り出した画像なのかオリジナル画像(訓練データ)なのかをDiscriminatorで鑑別します。
この処理を繰り返し、Generatorが作り出す画像(超解像画像)の質を高めていきます。
また、Generatorではより細かい特徴量を抽出するために多層に構築します。
そこで勾配消失を防ぐためにSkip connectionを用いてます。
####Generator
低解像度画像を元に高解像度画像を生成します。
入力画像をダウンサンプリングし、残差ブロックによる特徴量の抽出を行います。
その後、残差ブロックの出力とダウンサンプリングの出力とでSkip connectionを結んだ後、Pixel Shufflerによるアップサンプリングを行います。
# Super-resolution Image Generator
class Generator(tf.keras.Model):
def __init__(self, input_shape):
super().__init__()
input_shape_ps = (input_shape[0], input_shape[1], 64)
# Pre stage(Down Sampling)
self.pre = [
kl.Conv2D(64, kernel_size=9, strides=1,
padding="same", input_shape=input_shape),
kl.Activation(tf.nn.relu)
]
# Residual Block
self.res = [
[
Res_block(64, input_shape) for _ in range(7)
]
]
# Middle stage
self.middle = [
kl.Conv2D(64, kernel_size=3, strides=1, padding="same"),
kl.BatchNormalization()
]
# Pixel Shuffle(Up Sampling)
self.ps =[
[
Pixel_shuffler(128, input_shape_ps) for _ in range(2)
],
kl.Conv2D(3, kernel_size=9, strides=4, padding="same", activation="tanh")
]
def call(self, x):
# Pre stage
pre = x
for layer in self.pre:
pre = layer(pre)
# Residual Block
res = pre
for layer in self.res:
for l in layer:
res = l(res)
# Middle stage
middle = res
for layer in self.middle:
middle = layer(middle)
middle += pre
# Pixel Shuffle
out = middle
for layer in self.ps:
if isinstance(layer, list):
for l in layer:
out = l(out)
else:
out = layer(out)
return out
アーキテクチャは以下の通りです。
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_4 (InputLayer) [(None, 128, 128, 3) 0
__________________________________________________________________________________________________
conv2d_8 (Conv2D) (None, 128, 128, 64) 15616 input_4[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 128, 128, 64) 0 conv2d_8[0][0]
__________________________________________________________________________________________________
res_block (Res_block) (None, 128, 128, 64) 74368 activation_1[0][0]
__________________________________________________________________________________________________
res_block_1 (Res_block) (None, 128, 128, 64) 74368 res_block[0][0]
__________________________________________________________________________________________________
res_block_2 (Res_block) (None, 128, 128, 64) 74368 res_block_1[0][0]
__________________________________________________________________________________________________
res_block_3 (Res_block) (None, 128, 128, 64) 74368 res_block_2[0][0]
__________________________________________________________________________________________________
res_block_4 (Res_block) (None, 128, 128, 64) 74368 res_block_3[0][0]
__________________________________________________________________________________________________
res_block_5 (Res_block) (None, 128, 128, 64) 74368 res_block_4[0][0]
__________________________________________________________________________________________________
res_block_6 (Res_block) (None, 128, 128, 64) 74368 res_block_5[0][0]
__________________________________________________________________________________________________
conv2d_23 (Conv2D) (None, 128, 128, 64) 36928 res_block_6[0][0]
__________________________________________________________________________________________________
batch_normalization_21 (BatchNo (None, 128, 128, 64) 256 conv2d_23[0][0]
__________________________________________________________________________________________________
tf_op_layer_AddV2 (TensorFlowOp [(None, 128, 128, 64 0 batch_normalization_21[0][0]
activation_1[0][0]
__________________________________________________________________________________________________
pixel_shuffler (Pixel_shuffler) (None, 256, 256, 32) 73856 tf_op_layer_AddV2[0][0]
__________________________________________________________________________________________________
pixel_shuffler_1 (Pixel_shuffle (None, 512, 512, 32) 36992 pixel_shuffler[0][0]
__________________________________________________________________________________________________
conv2d_26 (Conv2D) (None, 128, 128, 3) 7779 pixel_shuffler_1[0][0]
==================================================================================================
Total params: 692,003
Trainable params: 690,083
Non-trainable params: 1,920
__________________________________________________________________________________________________
####Residual Block
残差ブロックです。ResNetで使われてるやつです。
特に説明はしなくていいでしょう。
# Residual Block
class Res_block(tf.keras.Model):
def __init__(self, ch, input_shape):
super().__init__()
self.conv1 = kl.Conv2D(ch, kernel_size=3, strides=1,
padding="same", input_shape=input_shape)
self.bn1 = kl.BatchNormalization()
self.av1 = kl.Activation(tf.nn.relu)
self.conv2 = kl.Conv2D(ch, kernel_size=3, strides=1,
padding="same")
self.bn2 = kl.BatchNormalization()
self.add = kl.Add()
def call(self, x):
d1 = self.av1(self.bn1(self.conv1(x)))
d2 = self.bn2(self.conv2(d1))
return self.add([x, d2])
####Pixel Shuffler
Pixel Shufflerについてはこちらで詳しく解説されています。
従来、アップサンプリングでDeconvolutionという手法が使われていましたが、Deconvolutionは計算速度が遅く、Checkerboard Artifactとよばれる格子状の模様ができてしまう(詳細)という問題を抱えています。
そのため、近年ではDeconvolutionに代わってPixel Shufflerが用いられているそうです。
# Pixel Shuffle
class Pixel_shuffler(tf.keras.Model):
def __init__(self, out_ch, input_shape):
super().__init__()
self.conv = kl.Conv2D(out_ch, kernel_size=3, strides=1,
padding="same", input_shape=input_shape)
self.act = kl.Activation(tf.nn.relu)
# forward proc
def call(self, x):
d1 = self.conv(x)
d2 = self.act(tf.nn.depth_to_space(d1, 2))
return d2
####Discriminator
入力画像がGeneratorにより作成された超解像画像(偽物)と訓練データであるオリジナル画像(本物)のどちらであるか判別します。
# Discriminator
class Discriminator(tf.keras.Model):
def __init__(self, input_shape):
super().__init__()
self.conv1 = kl.Conv2D(64, kernel_size=3, strides=1,
padding="same", input_shape=input_shape)
self.act1 = kl.Activation(tf.nn.relu)
self.conv2 = kl.Conv2D(64, kernel_size=3, strides=2,
padding="same")
self.bn1 = kl.BatchNormalization()
self.act2 = kl.LeakyReLU()
self.conv3 = kl.Conv2D(128, kernel_size=3, strides=1,
padding="same")
self.bn2 = kl.BatchNormalization()
self.act3 = kl.LeakyReLU()
self.conv4 = kl.Conv2D(128, kernel_size=3, strides=2,
padding="same")
self.bn3 = kl.BatchNormalization()
self.act4 = kl.LeakyReLU()
self.conv5 = kl.Conv2D(256, kernel_size=3, strides=1,
padding="same")
self.bn4 = kl.BatchNormalization()
self.act5 = kl.LeakyReLU()
self.conv6 = kl.Conv2D(256, kernel_size=3, strides=2,
padding="same")
self.bn5 = kl.BatchNormalization()
self.act6 = kl.LeakyReLU()
self.conv7 = kl.Conv2D(512, kernel_size=3, strides=1,
padding="same")
self.bn6 = kl.BatchNormalization()
self.act7 = kl.LeakyReLU()
self.conv8 = kl.Conv2D(512, kernel_size=3, strides=2,
padding="same")
self.bn7 = kl.BatchNormalization()
self.act8 = kl.LeakyReLU()
self.flt = kl.Flatten()
self.dens1 = kl.Dense(1024, activation=kl.LeakyReLU())
self.dens2 = kl.Dense(1, activation="sigmoid")
def call(self, x):
d1 = self.act1(self.conv1(x))
d2 = self.act2(self.bn1(self.conv2(d1)))
d3 = self.act3(self.bn2(self.conv3(d2)))
d4 = self.act4(self.bn3(self.conv4(d3)))
d5 = self.act5(self.bn4(self.conv5(d4)))
d6 = self.act6(self.bn5(self.conv6(d5)))
d7 = self.act7(self.bn6(self.conv7(d6)))
d8 = self.act8(self.bn7(self.conv8(d7)))
d9 = self.dens1(self.flt(d8))
d10 = self.dens2(d9)
return d10
アーキテクチャは以下の通り。
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, 128, 128, 3)] 0
_________________________________________________________________
conv2d (Conv2D) (None, 128, 128, 64) 1792
_________________________________________________________________
activation (Activation) (None, 128, 128, 64) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 64, 64, 64) 36928
_________________________________________________________________
batch_normalization (BatchNo (None, 64, 64, 64) 256
_________________________________________________________________
leaky_re_lu (LeakyReLU) (None, 64, 64, 64) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 64, 64, 128) 73856
_________________________________________________________________
batch_normalization_1 (Batch (None, 64, 64, 128) 512
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU) (None, 64, 64, 128) 0
_________________________________________________________________
conv2d_3 (Conv2D) (None, 32, 32, 128) 147584
_________________________________________________________________
batch_normalization_2 (Batch (None, 32, 32, 128) 512
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU) (None, 32, 32, 128) 0
_________________________________________________________________
conv2d_4 (Conv2D) (None, 32, 32, 256) 295168
_________________________________________________________________
batch_normalization_3 (Batch (None, 32, 32, 256) 1024
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU) (None, 32, 32, 256) 0
_________________________________________________________________
conv2d_5 (Conv2D) (None, 16, 16, 256) 590080
_________________________________________________________________
batch_normalization_4 (Batch (None, 16, 16, 256) 1024
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU) (None, 16, 16, 256) 0
_________________________________________________________________
conv2d_6 (Conv2D) (None, 16, 16, 512) 1180160
_________________________________________________________________
batch_normalization_5 (Batch (None, 16, 16, 512) 2048
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU) (None, 16, 16, 512) 0
_________________________________________________________________
conv2d_7 (Conv2D) (None, 8, 8, 512) 2359808
_________________________________________________________________
batch_normalization_6 (Batch (None, 8, 8, 512) 2048
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU) (None, 8, 8, 512) 0
_________________________________________________________________
flatten (Flatten) (None, 32768) 0
_________________________________________________________________
dense (Dense) (None, 1024) 33555456
_________________________________________________________________
dense_1 (Dense) (None, 1) 1025
=================================================================
Total params: 38,249,281
Trainable params: 38,245,569
Non-trainable params: 3,712
_________________________________________________________________
####学習
trainer
クラスではモデルの構築や学習を担います。
#####__ init __メソッド
モデルの構築など学習の前準備を行います。
def __init__(self, lr_img, hr_img):
lr_shape = lr_img.shape # Low-resolution Image shape
hr_shape = hr_img.shape # High-resolution Image shape
# Content Loss Model setup
input_tensor = tf.keras.Input(shape=hr_shape)
self.vgg = VGG16(include_top=False, input_tensor=input_tensor)
self.vgg.trainable = False
self.vgg.outputs = [self.vgg.layers[9].output] # VGG16 block3_conv3 output
# Content Loss Model
self.cl_model = tf.keras.Model(input_tensor, self.vgg.outputs)
# Discriminator
discriminator_ = Discriminator(hr_shape)
inputs = tf.keras.Input(shape=hr_shape)
outputs = discriminator_(inputs)
self.discriminator = tf.keras.Model(inputs=inputs, outputs=outputs)
self.discriminator.compile(optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.BinaryCrossentropy(),
metrics=['accuracy'])
# Generator
self.generator = Generator(lr_shape)
# Combined Model setup
lr_input = tf.keras.Input(shape=lr_shape)
sr_output = self.generator(lr_input)
self.discriminator.trainable = False # Discriminator train Disable
d_fake = self.discriminator(sr_output)
# SRGAN Model
self.gan = tf.keras.Model(inputs=lr_input, outputs=[sr_output, d_fake])
self.gan.compile(optimizer=tf.keras.optimizers.Adam(),
loss=[self.Content_loss, tf.keras.losses.BinaryCrossentropy()],
loss_weights=[1., 1e-3])
このメソッドでやっていることは以下の3つです。
1. Content Loss算出で用いるモデルの構築
2. Discriminator構築
3. Combined Model(Generator+Discriminator)構築
順を追って説明していきます。
######1. Content Loss算出で用いるモデルの構築
Content Loss(詳細は後述)の算出に使うネットワークモデルの構築を行います。
# Content Loss Model setup
input_tensor = tf.keras.Input(shape=hr_shape)
self.vgg = VGG16(include_top=False, input_tensor=input_tensor)
self.vgg.trainable = False
self.vgg.outputs = [self.vgg.layers[9].output] # VGG16 block3_conv3 output
# Content Loss Model
self.cl_model = tf.keras.Model(input_tensor, self.vgg.outputs)
今回は学習済みVGG16モデルの10番目の出力をContent Lossにて使います。
######2. Discriminator構築
Discriminatorの構築を行います。
# Discriminator
discriminator_ = Discriminator(hr_shape)
inputs = tf.keras.Input(shape=hr_shape)
outputs = discriminator_(inputs)
self.discriminator = tf.keras.Model(inputs=inputs, outputs=outputs)
self.discriminator.compile(optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.BinaryCrossentropy(),
metrics=['accuracy'])
Discriminatorクラスのインスタンスdiscriminator_
に入力層を付与して、compileします。
loss_weightやlearning_rateはデフォルトのままにしてます。(めんどくさいので笑)
######3. Combined Model(Generator+Discriminator)構築
GeneratorとDiscriminatorをあわせたCombined Modelを構築します。
# Generator
self.generator = Generator(lr_shape)
# Combined Model setup
lr_input = tf.keras.Input(shape=lr_shape)
sr_output = self.generator(lr_input)
self.discriminator.trainable = False
d_fake = self.discriminator(sr_output)
self.gan = tf.keras.Model(inputs=lr_input, outputs=[sr_output, d_fake])
self.gan.compile(optimizer=tf.keras.optimizers.Adam(),
loss=[self.Content_loss, tf.keras.losses.BinaryCrossentropy()],
loss_weights=[1., 1e-3])
self.generator
に低解像度画像lr_input
を渡して**超解像画像の出力サイズsr_output
を取得し、
それをself.discriminator
に渡してself.discriminator
の出力サイズd_fake
**を得ます。
これらの出力サイズを用いて、Combined Modelのoutputsを定義します。
compileのlossについてですが、
outputsの**sr_output
に対しては、超解像画像sr_output
とオリジナル画像(訓練データ)との誤差を求めるself.Content_loss
**(後述)を指定してやります。
d_fake
に対しては、Discriminatorの判別結果d_fake
と正解ラベルとの誤差を求めるBinaryCrossentropyを指定してやります。
#####Content_lossメソッド
Content Lossについてはこちらの記事がわかりやすいと思います。
軽く説明すると、ただ高解像度画像と超解像画像とで平均二乗誤差をとると出力結果がぼやけてしまうため、訓練済みネットワークの中間層の出力を使います。
高解像度画像と超解像画像を訓練済みネットワークに流し、それぞれの出力で平均二乗誤差をとります。
高解像度画像から抽出した特徴量と一致すれば、超解像画像は高解像度画像の特徴を持っていると言え、超解像画像は高解像度画像に近しくなっていると言えるという原理です。
# Content loss
def Content_loss(self, hr_img, sr_img):
return K.mean(K.abs(K.square(self.cl_model(hr_img) - self.cl_model(sr_img))))
ここではVGG16の10番目の出力を用いて平均二乗誤差をとっています。
#####PSNRメソッド
画像における再現性の品質の尺度であるPSNR(Peak Signal-to-Noise Ratio, ピーク信号対雑音比)というものがあります。これは信号(画素)が取りうる最大のパワー(ピーク信号)と劣化をもたらすノイズ(雑音)の比率を表します。
単位はdB(デシベル)で、値が高いほど品質がよいとされてます。
しかし、必ずしもPSNR値と人間が感じる情報が一致するとは限らないです。
定義式
$$PSNR = 10 \log_{10}\frac{MAX^2}{MSE}$$
$MAX$は信号(画素)がとり得る最大のパワー(ピーク信号 = 最大値)。
$MSE$は平均二乗誤差。
$$MSE = \frac{1}{n} \sum_{i=1}^{n} (SR_i - HR_i)^2$$
$SR$は生成画像の画素ベクトル、$HR$は訓練データの画素ベクトルです。
平均二乗誤差の値がノイズの程度を表し、
MAXがピーク信号を表していて、それらを除算したものの対数が比率となります。
# PSNR
def psnr(self, hr_img, sr_img):
return cv2.PSNR(hr_img, sr_img)
OpenCVにあるPSNRメソッドを用いて算出しています。
#####trainメソッド
このメソッドで実際に学習を行っています。
def train(self, lr_imgs, hr_imgs, out_path, batch_size, epoch):
g_loss_plt = []
d_loss_plt = []
path = os.path.join(out_path, "graph.jpg")
plt.figure(figsize=(12.8, 8.0), dpi=100)
h_batch = int(batch_size / 2)
real_lab = np.ones((h_batch, 1)) # High-resolution image label
fake_lab = np.zeros((h_batch, 1)) # Super-resolution image label(Discriminator side)
gan_lab = np.ones((h_batch, 1))
# train run
for epoch in range(epoch):
# - Train Discriminator -
# High-resolution image random pickups
idx = np.random.randint(0, hr_imgs.shape[0], h_batch)
hr_img = hr_imgs[idx]
# Low-resolution image random pickups
lr_img = lr_imgs[idx]
# Discriminator enabled train
self.discriminator.trainable = True
# train by High-resolution image
d_real_loss = self.discriminator.train_on_batch(hr_img, real_lab)
# train by Super-resolution image
sr_img = self.generator.predict(lr_img)
d_fake_loss = self.discriminator.train_on_batch(sr_img, fake_lab)
# Discriminator average loss
d_loss = 0.5 * np.add(d_real_loss, d_fake_loss)
# - Train Generator -
# High-resolution image random pickups
idx = np.random.randint(0, hr_imgs.shape[0], h_batch)
hr_img = hr_imgs[idx]
# Low-resolution image random pickups
lr_img = lr_imgs[idx]
# train by Generator
self.discriminator.trainable = False
g_loss = self.gan.train_on_batch(lr_img, [hr_img, gan_lab])
# Epoch num, Discriminator/Generator loss, PSNR
print("Epoch: {0} D_loss: {1:.3f} G_loss: {2:.3f} PSNR: {3:.3f}".format(epoch+1, d_loss[0], g_loss[0], self.psnr(hr_img, sr_img)))
d_loss_plt.append(d_loss[0])
g_loss_plt.append(g_loss[0])
# Plotting the loss value
if (epoch+1) % 50 == 0:
plt.plot(d_loss_plt)
plt.plot(g_loss_plt)
plt.savefig(path, bbox_inches='tight', pad_inches=0.1)
print("___Training finished\n\n")
# Parameter-File and Graph Saving
print("___Saving parameter...")
self.generator.save_weights(os.path.join(out_path, "srgan.h5"))
plt.plot(d_loss_plt, label="D_loss")
plt.plot(g_loss_plt, label="G_loss")
plt.savefig(path, bbox_inches='tight', pad_inches=0.1)
print("___Successfully completed\n\n")
細かい部分の説明は省いて、DiscriminatorとGeneratorの学習部分だけ説明します。
まずは、Discriminatorの学習です。
最初に高解像度画像、低解像度画像からそれぞれバッチサイズの半分の数だけ取り出します。
# - Train Discriminator -
# High-resolution image random pickups
idx = np.random.randint(0, hr_imgs.shape[0], h_batch)
hr_img = hr_imgs[idx]
# Low-resolution image random pickups
lr_img = lr_imgs[idx]
次にDiscriminatorに高解像度画像を学習させます。
# train by High-resolution image
d_real_loss = self.discriminator.train_on_batch(hr_img, real_lab)
hr_img
が高解像度画像でreal_lab
がすべて 1 のラベルとなります。
これのバイナリ交差エントロピー誤差を取って学習していきます。
高解像度画像を学習させたら、超解像画像を学習させます。
# train by Super-resolution image
sr_img = self.generator.predict(lr_img)
d_fake_loss = self.discriminator.train_on_batch(sr_img, fake_lab)
Generatorに低解像度画像を流して、超解像画像sr_img
を得ます。
その超解像画像sr_img
とすべて 0 のラベルfake_lab
をDiscriminatorに渡して学習していきます。
これでDiscriminatorの学習は完了しました。
次にGeneratorの学習です。
厳密にはCombined Model(コード中ではgan
)に対しての学習ですが、
Combined Model中のDiscriminatorの学習はしないように設定して、Generatorのみを学習します。
まず、Discriminator同様に高解像度画像、低解像度画像からそれぞれバッチサイズの半分の数だけ取り出します。
# High-resolution image random pickups
idx = np.random.randint(0, hr_imgs.shape[0], h_batch)
hr_img = hr_imgs[idx]
# Low-resolution image random pickups
lr_img = lr_imgs[idx]
Generatorの学習ですが、先でも説明したとおり、Discriminatorの学習はしないということで、
self.discriminator.trainable = False
と記述して学習の無効化をします。
無効化したら、低解像度画像lr_img
と正解データに高解像度画像hr_img
とすべて 1 のラベルgan_lab
# train by Generator
self.discriminator.trainable = False
g_loss = self.gan.train_on_batch(lr_img, [hr_img, gan_lab])
Generatorが出力した超解像画像と高解像度画像とでContent Lossを取り、超解像画像とすべて 1 のラベルとでバイナリ交差エントロピー誤差を取ります。
これをEpoch数分繰り返して精度を高めていきます。
####データセット作成
必要とするデータは高解像度画像のみで大丈夫です。
低解像度画像は高解像度画像から作成します。
# Dataset creation
def create_dataset(data_dir, h, w, mag):
print("\n___Creating a dataset...")
prc = ['/', '-', '\\', '|']
cnt = 0
print("Number of image in a directory: {}".format(len(os.listdir(data_dir))))
lr_imgs = []
hr_imgs = []
for c in os.listdir(data_dir):
d = os.path.join(data_dir, c)
_, ext = os.path.splitext(c)
if ext.lower() not in ['.jpg', '.png', '.bmp']:
continue
img = cv2.imread(d)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (h, w)) # High-resolution image
img_low = cv2.resize(img, (int(h/mag), int(w/mag))) # Image reduction
img_low = cv2.resize(img_low, (h, w)) # Resize to original size
lr_imgs.append(img_low)
hr_imgs.append(img)
cnt += 1
print("\rLoading a LR-images and HR-images...{} ({} / {})".format(prc[cnt%4], cnt, len(os.listdir(data_dir))), end='')
print("\rLoading a LR-images and HR-images...Done ({} / {})".format(cnt, len(os.listdir(data_dir))), end='')
# Low-resolution image
lr_imgs = tf.convert_to_tensor(lr_imgs, np.float32)
lr_imgs = (lr_imgs.numpy() - 127.5) / 127.5
# High-resolution image
hr_imgs = tf.convert_to_tensor(hr_imgs, np.float32)
hr_imgs = (hr_imgs.numpy() - 127.5) / 127.5
print("\n___Successfully completed\n")
return lr_imgs, hr_imgs
まず、画像を読み込みます。OpenCVで読み込んだ場合、画素の並びがBGRとなるため、
cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
でRGBに変換します。その後、指定したサイズにリサイズを行います。
これで高解像度画像の準備はひとまずOKです。
次に、低解像度画像を作成します。指定した縮小倍率で割った幅・高さに縮小します。
その後、縮小する前のサイズにリサイズし直せば作成完了です。
img = cv2.imread(d)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (h, w)) # High-resolution image
img_low = cv2.resize(img, (int(h/mag), int(w/mag))) # Image reduction
img_low = cv2.resize(img_low, (h, w)) # Resize to original size
蛇足ですが、OpenCVのcv2.resize()
の補間アルゴリズムはデフォルトではBilinearが使われます。
###おわりに
前回、SRCNNの記事を書いて、思いの外反応が得られました。
なかなか、忙しくてQiitaにまとめ上げられなかったのですが、やっとSRGANの記事を書き上げることができました。
次回はSRGAN 推論フェーズ編を投稿する予定ですので、よろしければ見ていただければなと思います。