#1. ライブラリとPSNR指標
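ピーク信号対雑音比(PSNR, 単位 dB)を、復元画像が元画像にどれだけ近いか(画質の高さ)の指標とする。PSNR = 10·log10(MAX² / MSE) で、ここでは画素値を [0, 1] に正規化するので MAX = 1 となり、値が大きいほど元画像との誤差が小さい。
一般に 20dB 程度以上であれば、比較的良好な画質とされる。
画像の場合、概ね 30dB 以上であれば元の画質を高品質に保っていると言える。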
import numpy as np
import cv2
import keras
from keras.datasets import cifar10
from sklearn.model_selection import train_test_split
from keras.models import Sequential
from keras import backend as K
from keras import Input, Model, callbacks, layers, models
from keras.layers import (Add, BatchNormalization, Conv2D, Dense, LeakyReLU, MaxPooling2D, UpSampling2D)
from keras.layers.core import Activation, Dense
from keras.optimizers import Adam
def psnr(y_true, y_pred, max_val=1.0):
    """Return the peak signal-to-noise ratio (in dB) between two image batches.

    Computes ``10 * log10(max_val**2 / MSE)``, where MSE is the mean squared
    error taken over every element of the batch. The default ``max_val=1.0``
    matches pixel values scaled to [0, 1] (this script divides by 255).

    Args:
        y_true: Ground-truth images.
        y_pred: Predicted images, same shape as ``y_true``.
        max_val: Largest possible pixel value of the images.

    Returns:
        Scalar PSNR tensor; larger means ``y_pred`` is closer to ``y_true``.
    """
    mse = K.mean(K.flatten(y_true - y_pred) ** 2)
    # K.epsilon() keeps the log finite when the prediction is exact (MSE == 0),
    # which would otherwise produce +inf in the logged metric.
    return 10 * K.log((max_val ** 2) / (mse + K.epsilon())) / np.log(10)
#2. 学習データとテストデータの準備
CIFAR-10 の 32×32 画像をいったん 16×16 に縮小し、最近傍補間で 32×32 に拡大し直した粗い画像を入力データ、元の 32×32 画像を正解(教師)データとして用意する。
(x_train, y_train), (x_test, y_test) = cifar10.load_data()


def _make_coarse(image, low_res=(16, 16)):
    """Return *image* degraded by shrinking it to *low_res* and enlarging it back.

    The image is shrunk with cv2.resize's default interpolation, then enlarged
    to its original size with nearest-neighbour interpolation, which yields the
    blocky low-resolution input the network is trained to restore.

    Args:
        image: A single image array of shape (height, width, channels).
        low_res: Intermediate size as (width, height), cv2's dsize order.

    Returns:
        An array with the same height/width as *image*.
    """
    height, width = image.shape[:2]
    small = cv2.resize(image, dsize=low_res)
    # cv2.resize takes dsize as (width, height).
    return cv2.resize(small, dsize=(width, height), interpolation=cv2.INTER_NEAREST)


# Prepare the coarse (input) images: one per original image. Iterating over
# the arrays themselves avoids hard-coding the dataset sizes (50000 / 10000).
input_train = [_make_coarse(img) for img in x_train]
input_test = [_make_coarse(img) for img in x_test]
# Normalize: scale 8-bit pixel values into [0, 1], stored as float16.
_PIXEL_MAX = 255
x_train_normalized, x_test_normalized = (
    arr.astype('float16') / _PIXEL_MAX for arr in (x_train, x_test)
)
# Stack the per-image coarse inputs into single arrays before scaling them.
Input_train, Input_test = np.asarray(input_train), np.asarray(input_test)
Input_train_normalized, Input_test_normalized = (
    arr.astype('float16') / _PIXEL_MAX for arr in (Input_train, Input_test)
)
#3. モデルの定義
inputs = Input((32, 32, 3))
X = inputs
# Identity branch: the coarse input is added back at the end, so the
# convolution stack only has to learn a residual correction.
shortcut = X
# Main branch: three 3x3 ReLU convolutions widening the channels
# 16 -> 32 -> 64, then a 1x1 projection back to 3 channels to match the input.
for n_filters in (16, 32, 64):
    X = Conv2D(filters=n_filters, kernel_size=(3, 3), activation='relu', padding='same')(X)
X = Conv2D(filters=3, kernel_size=(1, 1))(X)
# Shortcut connection: element-wise sum of the residual and the input.
X = Add()([X, shortcut])
outputs = X
# Model: wrap the graph, train with MSE loss and report PSNR as a metric.
model = Model(inputs=[inputs], outputs=[outputs])
model.summary()
model.compile(optimizer=Adam(learning_rate=0.0002), loss='mse', metrics=[psnr])
# Stop once val_loss has not improved for 5 epochs and roll the weights back
# to the best epoch. Without restore_best_weights the model keeps the weights
# of the last epoch trained, which can be up to `patience` epochs past the best.
# NOTE(review): validation_data is the test set, so early stopping selects on
# test data and the evaluate() below is optimistic. Consider holding out part
# of the training set for validation instead (train_test_split is imported).
fit_callbacs = [
    callbacks.EarlyStopping(monitor='val_loss', patience=5, mode='min', restore_best_weights=True)
]
hist = model.fit(
    Input_train_normalized, x_train_normalized,
    epochs=30, batch_size=32, shuffle=True,
    validation_data=(Input_test_normalized, x_test_normalized),
    callbacks=fit_callbacs, verbose=1,
)
val_loss, val_psnr = model.evaluate(Input_test_normalized, x_test_normalized, verbose=1)
#4. 画像の表示
import matplotlib.pyplot as plt
import matplotlib.cm as cm

# Pick 10 distinct test images (randint could return repeated indices).
p = np.random.choice(len(x_test), size=10, replace=False)
# Original (ground-truth) images and their labels.
x_test_sampled = x_test[p]
y_test_sample_lable = y_test[p]
# Coarse input images.
Input_test_sampled = Input_test[p]
Input_test_sampled_normalized = Input_test_sampled.astype('float16')/255
# Predicted images.
Input_test_sampled_normalized_pred = model.predict(Input_test_sampled_normalized)
x_test1 = Input_test_sampled_normalized_pred*255
# The last layer is a linear conv plus a residual add, so predictions can fall
# outside [0, 1]. astype('uint8') does not saturate out-of-range values (they
# typically wrap, e.g. 257 -> 1), so round and clip to [0, 255] first.
Input_test_sampled_pred = np.clip(np.rint(x_test1), 0, 255).astype('uint8')

fig, axes = plt.subplots(3, 10, figsize=(20, 4))
plt.subplots_adjust(hspace=0.4)
# One row per image kind, top to bottom: coarse input, prediction, original.
# (cmap has no effect on 3-channel images; kept from the original.)
rows = (
    ("input", Input_test_sampled),
    ("output", Input_test_sampled_pred),
    ("original", x_test_sampled),
)
for i, label in enumerate(y_test_sample_lable):
    for row, (prefix, images) in enumerate(rows):
        ax = axes[row][i]
        ax.imshow(images[i], cmap=cm.gray_r)
        ax.axis('off')
        ax.set_title(prefix + str(label), fontsize=10.5)
plt.show()
#5. PSNRグラフと損失グラフ
# Training curves: PSNR on the left, loss on the right. The "Test" curves are
# the validation metrics recorded for validation_data during model.fit.
plt.rc('font', family='serif')
fig = plt.figure(figsize=(13, 6))
curve_panels = [
    # (subplot index, train key, val key, train label, val label,
    #  train colour, val colour, title, y label)
    (1, 'psnr', 'val_psnr', 'Train psnr', 'Test psnr', 'red', 'm', 'Model PSNR', 'PSNR'),
    (2, 'loss', 'val_loss', 'Train Loss', 'Test Loss', 'blue', 'c', 'Model Loss', 'loss'),
]
for (index, train_key, val_key, train_label, val_label,
     train_color, val_color, title, y_label) in curve_panels:
    plt.subplot(1, 2, index)
    plt.plot(hist.history[train_key], label=train_label, color=train_color)
    plt.plot(hist.history[val_key], label=val_label, color=val_color)
    plt.title(title)
    plt.xlabel('Epochs')
    plt.ylabel(y_label)
    plt.legend(bbox_to_anchor=(0.7, 0.5), loc='center left', borderaxespad=0, fontsize=8)
plt.show()