0
2

More than 1 year has passed since last update.

CIFAR-10の高解像度化

Last updated at Posted at 2021-12-31

1. ライブラリとPSNR指標

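ピーク信号対雑音比(PSNR)を、復元画像が元画像にどれだけ近いか(画質の高さ)を表す指標とする。
PSNR は $\mathrm{PSNR} = 10\log_{10}\left(\mathrm{MAX}^2 / \mathrm{MSE}\right)$ で定義される。本記事では画素値を 0〜1 に正規化しているため $\mathrm{MAX}=1$ となり、$\mathrm{PSNR} = -10\log_{10}(\mathrm{MSE})$ と書ける。
目安として、20 dB 程度以上で比較的良好、30 dB 以上であれば元の画質を高品質に保っているとされる。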

import numpy as np
import cv2

import keras
from keras.datasets import cifar10
from sklearn.model_selection import train_test_split
from keras.models import Sequential
from keras import backend as K
from keras import Input, Model, callbacks, layers, models
from keras.layers import (Add, BatchNormalization, Conv2D, Dense, LeakyReLU, MaxPooling2D, UpSampling2D)
from keras.layers.core import Activation, Dense
from keras.optimizers import Adam

def psnr(y_true, y_pred, max_val=1.0):
    """Return the peak signal-to-noise ratio (in dB) between two image batches.

    Used as a Keras metric, so it receives ``(y_true, y_pred)`` tensors and is
    written with backend ops only. The squared error is averaged over every
    element of the batch (the difference is flattened first), so the result
    is a single scalar per batch.

    Args:
        y_true: Ground-truth images (this script scales pixels to [0, 1]).
        y_pred: Predicted images, same shape as ``y_true``.
        max_val: Largest possible pixel value. The default 1.0 matches the
            /255 normalization used in this script; pass 255.0 for raw pixels.

    Returns:
        Scalar tensor equal to ``10 * log10(max_val**2 / MSE)``.
    """
    mse = K.mean(K.square(K.flatten(y_true - y_pred)))
    # Floor the MSE so an exact reconstruction yields a large finite value
    # instead of log(0) -> +inf, which would make the averaged metric infinite.
    mse = K.maximum(mse, K.epsilon())
    # K.log is the natural log; dividing by ln(10) converts it to log10.
    return 10.0 * K.log((max_val ** 2) / mse) / np.log(10)

2. 学習データとテストデータの準備

入力データとして、32×32ピクセルの CIFAR-10 画像をいったん 16×16 ピクセルに縮小し、最近傍補間で 32×32 ピクセルに拡大し直した粗い画像を用意する。元の 32×32 画像を正解(教師)データとする。

(x_train, y_train), (x_test, y_test) = cifar10.load_data()


def degrade_image(image, low_size=(16, 16)):
    """Return a coarse version of *image* at its original size.

    The image is shrunk to ``low_size`` (OpenCV ``(width, height)`` order)
    using ``cv2.resize``'s default interpolation, then enlarged back to its
    original width/height with nearest-neighbour interpolation. The result is
    the blocky image fed to the model as input.
    """
    height, width = image.shape[:2]
    small = cv2.resize(image, dsize=low_size)
    # cv2.resize does not modify its input, so no defensive copy is needed.
    return cv2.resize(small, dsize=(width, height), interpolation=cv2.INTER_NEAREST)


# Build the coarse input images. Iterating the arrays themselves (instead of
# fixed 50000/10000 counts) keeps this correct for any subset of the data.
input_train = [degrade_image(image) for image in x_train]
input_test = [degrade_image(image) for image in x_test]

# Normalize the original (target) images from 0-255 to 0-1, stored as float16.
x_train_normalized = x_train.astype('float16')/255
x_test_normalized = x_test.astype('float16')/255

# Stack the coarse inputs into arrays and normalize them the same way.
Input_train = np.asarray(input_train)
Input_test = np.asarray(input_test)
Input_train_normalized = Input_train.astype('float16')/255
Input_test_normalized = Input_test.astype('float16')/255

3. モデルの定義

# Residual CNN: a small convolutional branch produces a 3-channel image that
# is added back onto the (coarse) 32x32x3 input image.
inputs = Input((32, 32, 3))

X = inputs
shortcut = X

# Main branch: 3x3 convolutions widening 16 -> 32 -> 64 channels, then a 1x1
# convolution projecting back to 3 channels so it can be added to the input.
X = Conv2D(filters=16, kernel_size=(3, 3), activation='relu', padding='same')(X)
X = Conv2D(filters=32, kernel_size=(3, 3), activation='relu', padding='same')(X)
X = Conv2D(filters=64, kernel_size=(3, 3), activation='relu', padding='same')(X)
X = Conv2D(filters=3, kernel_size=(1, 1))(X)

# Shortcut (residual) connection: output = branch + input.
X = Add()([X, shortcut])

outputs = X

# Model
model = Model(inputs=[inputs], outputs=[outputs])

model.summary()

# Train on MSE and additionally track PSNR as a metric.
model.compile(optimizer=Adam(learning_rate=0.0002), loss='mse', metrics=[psnr])

# Stop after 5 epochs without val_loss improvement. When stopping triggers,
# roll back to the best epoch's weights, so evaluate() below scores the best
# model instead of one up to 5 epochs past the optimum.
# NOTE(review): the test set doubles as validation data here, so early
# stopping is tuned on it and the final test metrics are optimistic; consider
# holding out part of the training set instead (train_test_split is imported).
fit_callbacks = [
    callbacks.EarlyStopping(monitor='val_loss', patience=5, mode='min',
                            restore_best_weights=True),
]

hist = model.fit(Input_train_normalized, x_train_normalized, epochs=30, batch_size=32, shuffle=True,
                 validation_data=(Input_test_normalized, x_test_normalized),
                 callbacks=fit_callbacks, verbose=1)

val_loss, val_psnr = model.evaluate(Input_test_normalized, x_test_normalized, verbose=1)

4. 画像の表示

import matplotlib.pyplot as plt
import matplotlib.cm as cm

n_samples = 10

# Pick distinct random test images (np.random.randint could repeat an index).
p = np.random.choice(len(x_test), n_samples, replace=False)

# Original images and their labels
x_test_sampled = x_test[p]
y_test_sample_lable = y_test[p]

# Coarse input images
Input_test_sampled = Input_test[p]
Input_test_sampled_normalized = Input_test_sampled.astype('float16')/255

# Predicted images. The residual model's output is not bounded to [0, 1], so
# round and clip to 0-255 before casting: out-of-range float -> uint8 casts
# are not well defined and typically wrap around, showing up as speckles.
Input_test_sampled_normalized_pred = model.predict(Input_test_sampled_normalized)
x_test1 = np.clip(np.rint(Input_test_sampled_normalized_pred * 255), 0, 255)
Input_test_sampled_pred = x_test1.astype('uint8')

# squeeze=False keeps axes 2-D even when n_samples == 1.
fig, axes = plt.subplots(3, n_samples, figsize=(20, 4), squeeze=False)
plt.subplots_adjust(hspace=0.4)

# One column per sample; rows are coarse input / model output / original.
rows = (
    ("input", Input_test_sampled),
    ("output", Input_test_sampled_pred),
    ("original", x_test_sampled),
)
for i, label in enumerate(y_test_sample_lable):
    for row, (prefix, images) in enumerate(rows):
        ax = axes[row][i]
        # cmap is ignored by imshow for RGB images; kept as in the original.
        ax.imshow(images[i], cmap=cm.gray_r)
        ax.axis('off')
        ax.set_title(prefix + str(label), fontsize=10.5)

plt.show()

高解像度.png

5. PSNRグラフと損失グラフ

# Training curves side by side: PSNR (left) and MSE loss (right), each with
# the training value and the value on the validation (here: test) data.
plt.rc('font', family='serif')
fig = plt.figure(figsize=(13, 6))

# (train key, val key, train label, val label, train color, val color, title, y label)
history_panels = (
    ('psnr', 'val_psnr', 'Train psnr', 'Test psnr', 'red', 'm', 'Model PSNR', 'PSNR'),
    ('loss', 'val_loss', 'Train Loss', 'Test Loss', 'blue', 'c', 'Model Loss', 'loss'),
)

for panel_index, panel in enumerate(history_panels, start=1):
    train_key, val_key, train_label, val_label, train_color, val_color, title, y_label = panel
    plt.subplot(1, 2, panel_index)
    plt.plot(hist.history[train_key], label=train_label, color=train_color)
    plt.plot(hist.history[val_key], label=val_label, color=val_color)
    plt.title(title)
    plt.xlabel('Epochs')
    plt.ylabel(y_label)
    plt.legend(bbox_to_anchor=(0.7, 0.5), loc='center left', borderaxespad=0, fontsize=8)

plt.show()

高解像度_学習率png.png

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2