0
0

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?

【超解像】SwinIRをクラスで実行できるようにする

Posted at

はじめに

コマンドラインから実行する前提で書かれている main_test_swinir.py をクラスに変換するのが案外手間だったので、変換後のコードを残しておきます。

SwinIRのGitHub: https://github.com/JingyunLiang/SwinIR

コマンドライン実行からクラス実行へ

python3 main_test_swinir.py --task real_sr --scale 4 --model_path model_zoo/swinir/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth --folder_lq testsets/

上記コマンドと同じ結果が得られれば十分だったので、クラス化にあたってこのタスクに不要な引数は削除しています。

# 変更後のコード

import cv2
import glob
import numpy as np
from collections import OrderedDict
import os
import torch
import requests

from models.network_swinir import SwinIR as net
from utils import util_calculate_psnr_ssim as util

class SwinIRInference:
    """Run SwinIR real-world super-resolution over every image in a folder.

    Class-based replacement for ``main_test_swinir.py --task real_sr``; only the
    options needed for that task are kept.

    Args:
        model_path: Path to the ``.pth`` checkpoint. If the file does not exist it
            is downloaded from the SwinIR GitHub release ``v0.0`` by its basename.
        scale: Upscaling factor given to the network (``upscale``); also used to
            crop the output and as the PSNR/SSIM border.
        large_model: Build the 9-stage / 240-dim variant instead of the
            6-stage / 180-dim one. Results go to a ``_large``-suffixed folder.
        target_folder: Folder whose files are super-resolved.
    """

    # Window size the network is built with; inputs are padded to a multiple of it.
    WINDOW_SIZE = 8

    def __init__(self, model_path, scale=4, large_model=False, target_folder="your/target/folder"):
        self.model_path = model_path
        self.scale = scale
        self.large_model = large_model
        self.target_folder = target_folder
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.define_model()
        # Per-image metrics, filled only when a ground-truth image is available.
        self.results = OrderedDict(psnr=[], ssim=[], psnr_y=[], ssim_y=[], psnrb=[], psnrb_y=[])

    def define_model(self):
        """Build the network, load its weights (downloading if needed) and return it in eval mode."""
        if not os.path.exists(self.model_path):
            self.download_model()
        model = self._initialize_model_structure()
        model = self._load_pretrained_model(model)
        return model.to(self.device).eval()

    def download_model(self):
        """Download ``basename(model_path)`` from the SwinIR ``v0.0`` GitHub release.

        Raises:
            requests.HTTPError: if the server returns an error status. Nothing is
                written in that case, so an error page is never saved as a ``.pth``.
        """
        model_dir = os.path.dirname(self.model_path)
        if model_dir:  # a bare filename has no directory part; os.makedirs('') raises
            os.makedirs(model_dir, exist_ok=True)
        url = f'https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/{os.path.basename(self.model_path)}'
        print(f'Downloading model from {url}')
        r = requests.get(url, allow_redirects=True, timeout=60)
        r.raise_for_status()
        with open(self.model_path, 'wb') as f:
            f.write(r.content)

    def _initialize_model_structure(self):
        """Return an untrained SwinIR network configured for real-world SR."""
        if self.large_model:
            model = net(upscale=self.scale, in_chans=3, img_size=64, window_size=self.WINDOW_SIZE, img_range=1.,
                        depths=[6] * 9, embed_dim=240, num_heads=[8] * 9, mlp_ratio=2,
                        upsampler='nearest+conv', resi_connection='3conv')
        else:
            model = net(upscale=self.scale, in_chans=3, img_size=64, window_size=self.WINDOW_SIZE, img_range=1.,
                        depths=[6] * 6, embed_dim=180, num_heads=[6] * 6, mlp_ratio=2,
                        upsampler='nearest+conv', resi_connection='1conv')
        return model

    def _load_pretrained_model(self, model):
        """Load the checkpoint at ``self.model_path`` into *model* and return it.

        If the checkpoint is a dict with a ``'params_ema'`` entry, that entry is
        used as the state dict; otherwise the loaded object itself is.
        """
        # map_location lets a checkpoint saved from GPU load on a CPU-only machine.
        pretrained_model = torch.load(self.model_path, map_location=self.device)
        param_key_g = 'params_ema'
        state_dict = pretrained_model[param_key_g] if param_key_g in pretrained_model else pretrained_model
        model.load_state_dict(state_dict, strict=True)
        return model

    def setup_folders(self):
        """Create the results directory; return ``(input_folder, save_dir)``."""
        save_dir = f'results/swinir_real_sr_x{self.scale}' + ('_large' if self.large_model else '')
        os.makedirs(save_dir, exist_ok=True)
        return self.target_folder, save_dir

    def get_image_pair(self, path):
        """Read *path* as a BGR float32 image in ``[0, 1]``.

        Returns:
            ``(imgname, img_lq, img_gt)``. ``img_gt`` is always None because
            real-world SR has no ground truth.

        Raises:
            ValueError: if OpenCV cannot decode the file.
        """
        imgname, _ = os.path.splitext(os.path.basename(path))
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        if img is None:  # cv2.imread signals failure by returning None
            raise ValueError(f'could not read image: {path}')
        img_lq = img.astype(np.float32) / 255.
        img_gt = None  # GT is None for real-world image super-resolution tasks
        return imgname, img_lq, img_gt

    def run_inference(self):
        """Super-resolve every readable image in the target folder and save the results."""
        folder, save_dir = self.setup_folders()
        window_size = self.WINDOW_SIZE

        for idx, path in enumerate(sorted(glob.glob(os.path.join(folder, '*')))):
            if not os.path.isfile(path):
                continue
            try:
                imgname, img_lq, img_gt = self.get_image_pair(path)
            except ValueError as err:
                print(f'Skipping {path}: {err}')
                continue
            output = self.process_image(img_lq, window_size)
            self.save_output(output, save_dir, imgname)
            # Upstream SwinIR evaluates SR with crop_border=scale (not the window size).
            self.evaluate_results(output, img_gt, imgname, idx, self.scale)

        self.summarize_results()

    def process_image(self, img_lq, window_size):
        """Run the model on one HWC BGR float image.

        Returns:
            A ``(1, C, H*scale, W*scale)`` tensor in RGB order on ``self.device``.
        """
        if img_lq.shape[2] == 3:
            img_lq = img_lq[..., [2, 1, 0]]  # BGR -> RGB
        img_lq = np.transpose(img_lq, (2, 0, 1))  # HWC -> CHW
        img_lq = torch.from_numpy(img_lq).float().unsqueeze(0).to(self.device)

        # Remember the unpadded size: the output must be cropped back to it.
        _, _, h_old, w_old = img_lq.size()
        h_pad = (h_old // window_size + 1) * window_size - h_old
        w_pad = (w_old // window_size + 1) * window_size - w_old
        # Mirror-pad bottom/right so both sides become multiples of window_size.
        img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[:, :, :h_old + h_pad, :]
        img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]

        with torch.no_grad():
            output = self.model(img_lq)
        return output[..., :h_old * self.scale, :w_old * self.scale]

    @staticmethod
    def _tensor_to_image(output):
        """Convert a model output tensor to an HWC BGR (or HW grayscale) uint8 array.

        Values are clamped to ``[0, 1]`` on a copy; the caller's tensor is not modified.
        """
        img = output.detach().float().cpu().clamp(0, 1).numpy()
        if img.ndim == 4:
            img = img[0]  # drop the batch axis
        if img.ndim == 3 and img.shape[0] == 3:
            img = np.transpose(img[[2, 1, 0], :, :], (1, 2, 0))  # RGB CHW -> BGR HWC
        elif img.ndim == 3 and img.shape[0] == 1:
            img = img[0]
        return (img * 255.0).round().astype(np.uint8)

    def save_output(self, output, save_dir, imgname):
        """Write *output* to ``<save_dir>/<imgname>_SwinIR.png``."""
        cv2.imwrite(os.path.join(save_dir, f'{imgname}_SwinIR.png'), self._tensor_to_image(output))

    def evaluate_results(self, output, img_gt, imgname, idx, border):
        """Record PSNR/SSIM against *img_gt* when it exists; otherwise just log progress.

        Args:
            output: Model output tensor, or an already-converted uint8 image.
            img_gt: Ground-truth float image in ``[0, 1]``, or None.
            border: Pixels cropped from each edge before computing the metrics.
        """
        if img_gt is None:
            print(f'Testing {idx} {imgname}')
            return

        if torch.is_tensor(output):
            output = self._tensor_to_image(output)
        img_gt = (img_gt * 255.0).round().astype(np.uint8)
        # Metrics need identical shapes: crop the GT to the output's spatial size.
        img_gt = np.squeeze(img_gt[:output.shape[0], :output.shape[1], ...])
        psnr = util.calculate_psnr(output, img_gt, crop_border=border)
        ssim = util.calculate_ssim(output, img_gt, crop_border=border)
        self.results['psnr'].append(psnr)
        self.results['ssim'].append(ssim)

        print(f'Testing {idx} {imgname} - PSNR: {psnr:.2f} dB; SSIM: {ssim:.4f}')

    def summarize_results(self):
        """Print the average PSNR/SSIM if any metrics were recorded."""
        if self.results['psnr']:
            ave_psnr = np.mean(self.results['psnr'])
            ave_ssim = np.mean(self.results['ssim'])
            print(f'\n-- Average PSNR: {ave_psnr:.2f} dB; SSIM: {ave_ssim:.4f}')

if __name__ == '__main__':
    # Real-world x4 super-resolution with the medium-size GAN checkpoint, applied
    # to SwinIRInference's default target_folder.
    SwinIRInference(
        model_path="model_zoo/swinir/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth",
    ).run_inference()

お役に立てれば幸いです。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?