はじめに
コマンドラインから実行する前提で書かれている `main_test_swinir.py` をクラスに変換するのが案外手間だったので、備忘録として残しておきます。
SwinIRのGitHub: https://github.com/JingyunLiang/SwinIR
コマンドライン実行 -> Class実行へ
python3 main_test_swinir.py --task real_sr --scale 4 --model_path model_zoo/swinir/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth --folder_lq testsets/
上記コマンドと同じ結果が得られれば十分だったので、クラス版ではこの用途に不要な引数を削除しています。
# 変更後のコード
import cv2
import glob
import numpy as np
from collections import OrderedDict
import os
import torch
import requests
from models.network_swinir import SwinIR as net
from utils import util_calculate_psnr_ssim as util
class SwinIRInference:
def __init__(self, model_path, scale=4, large_model=False, target_folder="your/target/folder"):
self.model_path = model_path
self.scale = scale
self.large_model = large_model
self.target_folder = target_folder
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model = self.define_model()
self.results = OrderedDict(psnr=[], ssim=[], psnr_y=[], ssim_y=[], psnrb=[], psnrb_y=[])
def define_model(self):
if not os.path.exists(self.model_path):
self.download_model()
model = self._initialize_model_structure()
model = self._load_pretrained_model(model)
return model.to(self.device).eval()
def download_model(self):
os.makedirs(os.path.dirname(self.model_path), exist_ok=True)
url = f'https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/{os.path.basename(self.model_path)}'
print(f'Downloading model from {url}')
r = requests.get(url, allow_redirects=True)
with open(self.model_path, 'wb') as f:
f.write(r.content)
def _initialize_model_structure(self):
if self.large_model:
model = net(upscale=self.scale, in_chans=3, img_size=64, window_size=8, img_range=1.,
depths=[6] * 9, embed_dim=240, num_heads=[8] * 9, mlp_ratio=2,
upsampler='nearest+conv', resi_connection='3conv')
else:
model = net(upscale=self.scale, in_chans=3, img_size=64, window_size=8, img_range=1.,
depths=[6] * 6, embed_dim=180, num_heads=[6] * 6, mlp_ratio=2,
upsampler='nearest+conv', resi_connection='1conv')
return model
def _load_pretrained_model(self, model):
pretrained_model = torch.load(self.model_path)
param_key_g = 'params_ema'
model.load_state_dict(pretrained_model[param_key_g] if param_key_g in pretrained_model else pretrained_model, strict=True)
return model
def setup_folders(self):
save_dir = f'results/swinir_real_sr_x{self.scale}' + ('_large' if self.large_model else '')
os.makedirs(save_dir, exist_ok=True)
return self.target_folder, save_dir
def get_image_pair(self, path):
imgname, _ = os.path.splitext(os.path.basename(path))
img_lq = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
img_gt = None # GT is None for real-world image super-resolution tasks
return imgname, img_lq, img_gt
def run_inference(self):
folder, save_dir = self.setup_folders()
window_size = 8
for idx, path in enumerate(sorted(glob.glob(os.path.join(folder, '*')))):
imgname, img_lq, img_gt = self.get_image_pair(path)
output = self.process_image(img_lq, window_size)
self.save_output(output, save_dir, imgname)
self.evaluate_results(output, img_gt, imgname, idx, window_size)
self.summarize_results()
def process_image(self, img_lq, window_size):
img_lq = np.transpose(img_lq[..., [2, 1, 0]] if img_lq.shape[2] == 3 else img_lq, (2, 0, 1))
img_lq = torch.from_numpy(img_lq).float().unsqueeze(0).to(self.device)
h_pad = (img_lq.size(2) // window_size + 1) * window_size - img_lq.size(2)
w_pad = (img_lq.size(3) // window_size + 1) * window_size - img_lq.size(3)
img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[:, :, :img_lq.size(2) + h_pad, :]
img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :img_lq.size(3) + w_pad]
with torch.no_grad():
output = self.model(img_lq)[..., :img_lq.size(2) * self.scale, :img_lq.size(3) * self.scale]
return output
def save_output(self, output, save_dir, imgname):
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) if output.ndim == 3 else output
output = (output * 255.0).round().astype(np.uint8)
cv2.imwrite(f'{save_dir}/{imgname}_SwinIR.png', output)
def evaluate_results(self, output, img_gt, imgname, idx, border):
if img_gt is not None:
img_gt = (img_gt * 255.0).round().astype(np.uint8)
psnr = util.calculate_psnr(output, img_gt, crop_border=border)
ssim = util.calculate_ssim(output, img_gt, crop_border=border)
self.results['psnr'].append(psnr)
self.results['ssim'].append(ssim)
print(f'Testing {idx} {imgname} - PSNR: {psnr:.2f} dB; SSIM: {ssim:.4f}')
else:
print(f'Testing {idx} {imgname}')
def summarize_results(self):
if self.results['psnr']:
ave_psnr = np.mean(self.results['psnr'])
ave_ssim = np.mean(self.results['ssim'])
print(f'\n-- Average PSNR: {ave_psnr:.2f} dB; SSIM: {ave_ssim:.4f}')
if __name__ == '__main__':
    # Real-world x4 SwinIR-M (GAN) checkpoint; downloaded on first run if absent.
    checkpoint = "model_zoo/swinir/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth"
    SwinIRInference(model_path=checkpoint).run_inference()
お役に立てれば幸いです。