相変わらず,ライブラリを使えば簡単にできる.一方で,imagenetなどのデータと比較すると運転データはあまり画質が上がっているとはいえず,用いるデータにドメインを合わせて,自分のデータを用意して学習しないといけないのかな,とも思った.
!git clone https://github.com/krasserm/super-resolution.git
cd /content/drive/MyDrive/Lectures/DSS/super-resolution/
!pip install tensorflow==2.5 -q
import os
import matplotlib.pyplot as plt
from data import DIV2K
from model.srgan import generator, discriminator
from train import SrganTrainer, SrganGeneratorTrainer
%matplotlib inline
# Demo frame that will be super-resolved below.
img_path = "/content/drive/MyDrive/Lectures/DSS/data_blurred/frame_100/demo/frame_00000.jpg"

# Location of model weights (needed for demo)
weights_dir = 'weights/srgan'


def weights_file(filename):
    """Return the path of *filename* inside ``weights_dir``."""
    return os.path.join(weights_dir, filename)


os.makedirs(weights_dir, exist_ok=True)

# Two SRGAN generators with the same architecture but different weights:
# ``pre_generator`` holds the pixel-loss pre-trained weights,
# ``gan_generator`` the adversarially fine-tuned ones.
pre_generator = generator()
gan_generator = generator()
pre_generator.load_weights(weights_file('pre_generator.h5'))
gan_generator.load_weights(weights_file('gan_generator.h5'))
from model import resolve_single
from utils import load_image
def resolve_and_plot(lr_image_path):
    """Super-resolve one image with both generators and plot the results.

    Parameters
    ----------
    lr_image_path : str
        Path to the low-resolution input image.
    """
    lr = load_image(lr_image_path)

    # Same input through both models: pixel-loss pre-trained vs. GAN fine-tuned.
    pre_sr = resolve_single(pre_generator, lr)
    gan_sr = resolve_single(gan_generator, lr)

    plt.figure(figsize=(20, 20))

    images = [lr, pre_sr, gan_sr]
    titles = ['LR', 'SR (PRE)', 'SR (GAN)']
    # 2x2 grid: LR top-left, SR outputs on the bottom row (slot 2 left empty).
    positions = [1, 3, 4]

    for img, title, pos in zip(images, titles, positions):
        plt.subplot(2, 2, pos)
        plt.imshow(img)
        plt.title(title)
        plt.xticks([])
        plt.yticks([])


resolve_and_plot(img_path)