

I want to compute LPIPS in BasicSR!

Posted at 2023-03-16

Introduction

When it comes to a toolbox for easily training and testing super-resolution models, BasicSR is the one, right?

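While doing research, there are times (about once a week, right?) when you want to compute LPIPS, a well-known evaluation metric for super-resolution.
But BasicSR does not implement it…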

So here is how to implement it. The code below was run in a conda environment.
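For context, LPIPS compares deep features of the two images rather than raw pixels. The NumPy sketch below shows only the final aggregation step, following the reference lpips implementation: for each layer, unit-normalize the features along the channel axis, take the squared difference, weight each channel with a non-negative weight, average over space, and sum over layers. The feature maps and weights here are random stand-ins (an assumption for illustration); in the real package they come from a pretrained VGG or AlexNet plus learned 1x1 linear layers. unit_normalize and lpips_distance are hypothetical helper names.

```python
import numpy as np


def unit_normalize(feat, eps=1e-10):
    """Scale each spatial position's channel vector of a (C, H, W) map to unit length."""
    norm = np.sqrt(np.sum(feat ** 2, axis=0, keepdims=True))
    return feat / (norm + eps)


def lpips_distance(feats_x, feats_y, weights):
    """LPIPS-style distance from precomputed features (sketch, not the lpips package).

    feats_x, feats_y: lists of (C_l, H_l, W_l) feature maps, one per layer.
    weights: list of non-negative (C_l,) channel weights, one per layer.
    """
    total = 0.0
    for fx, fy, w in zip(feats_x, feats_y, weights):
        sq_diff = (unit_normalize(fx) - unit_normalize(fy)) ** 2  # (C, H, W)
        weighted = np.tensordot(w, sq_diff, axes=(0, 0))           # (H, W): sum_c w_c * diff_c
        total += weighted.mean()                                   # average over H and W
    return float(total)


# Random stand-ins for two layers of features (hypothetical shapes)
rng = np.random.default_rng(0)
feats_a = [rng.standard_normal((4, 5, 5)), rng.standard_normal((8, 3, 3))]
feats_b = [rng.standard_normal((4, 5, 5)), rng.standard_normal((8, 3, 3))]
weights = [np.ones(4), np.ones(8)]
print(lpips_distance(feats_a, feats_b, weights))
```

Identical feature stacks give a distance of exactly 0, and because of the unit normalization, rescaling one image's features leaves the distance essentially unchanged. Lower values mean more perceptually similar images.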

Installing lpips and BasicSR

First, install lpips.

Installing lpips
    pip install lpips

Next, install BasicSR. It is simple: just run the following command in the terminal.

Installing basicsr
    pip install basicsr

The BasicSR package is then installed at a path like this:

Location of basicsr
    /home/<usr>/anaconda3/lib/python3.9/site-packages/basicsr
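The exact path depends on your environment (Anaconda location, Python version). As a sketch, assuming you only want to know where a package lives without importing it, the standard library's importlib.util.find_spec can locate it; package_dir is a hypothetical helper name.

```python
import importlib.util
from pathlib import Path
from typing import Optional


def package_dir(name: str) -> Optional[Path]:
    """Return the install directory of a top-level package, or None if it is not installed.

    For a top-level name, find_spec locates the package without importing (executing) it.
    """
    spec = importlib.util.find_spec(name)
    if spec is None or spec.origin is None:
        return None
    # For a regular package, spec.origin is .../<name>/__init__.py
    return Path(spec.origin).parent


if __name__ == "__main__":
    basicsr_dir = package_dir("basicsr")
    print(basicsr_dir)  # e.g. .../site-packages/basicsr, or None if not installed
    if basicsr_dir is not None:
        print(basicsr_dir / "metrics")  # the folder that holds the metric code
```

Running this inside the same conda environment you installed into avoids editing a copy of BasicSR that belongs to a different Python.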

Adding LPIPS

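To add the LPIPS code, edit the metrics folder under basicsr (basicsr/metrics): replace __init__.py with the version below, which adds the line from .lpips import calculate_lpips, and create a new file lpips.py next to it. Just copy and paste!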

__init__.py
    from copy import deepcopy

    from basicsr.utils.registry import METRIC_REGISTRY
    from .niqe import calculate_niqe
    from .psnr_ssim import calculate_psnr, calculate_ssim
    from .lpips import calculate_lpips  # importing the module registers calculate_lpips

    __all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe', 'calculate_lpips']


    def calculate_metric(data, opt):
        """Calculate metric from data and options.

        Args:
            opt (dict): Configuration. It must contain:
                type (str): Metric type, i.e. the name of a registered metric function.
        """
        opt = deepcopy(opt)
        metric_type = opt.pop('type')
        metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
        return metric
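Why does __init__.py need the extra import? The @METRIC_REGISTRY.register() decorator in lpips.py only runs when that module is imported, and calculate_metric looks the function up by name (the type string in the config). The sketch below mimics that mechanism with a simplified, hypothetical Registry class and a dummy metric; it is a stand-in for illustration, not BasicSR's actual Registry.

```python
from copy import deepcopy


class Registry:
    """Simplified stand-in for BasicSR's registry (hypothetical, for illustration)."""

    def __init__(self, name):
        self._name = name
        self._obj_map = {}

    def register(self):
        def deco(func):
            # Functions are keyed by their __name__, e.g. 'calculate_lpips'
            if func.__name__ in self._obj_map:
                raise KeyError(f"'{func.__name__}' is already registered in '{self._name}'")
            self._obj_map[func.__name__] = func
            return func
        return deco

    def get(self, name):
        if name not in self._obj_map:
            raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
        return self._obj_map[name]


METRIC_REGISTRY = Registry('metric')


@METRIC_REGISTRY.register()  # runs at import time, just like in lpips.py
def calculate_dummy(img, img2, scale=1.0, **kwargs):
    # Hypothetical metric: scaled mean absolute difference, only to exercise dispatch
    return scale * sum(abs(a - b) for a, b in zip(img, img2)) / len(img)


def calculate_metric(data, opt):
    opt = deepcopy(opt)            # keep the caller's config untouched
    metric_type = opt.pop('type')  # e.g. 'calculate_dummy'
    return METRIC_REGISTRY.get(metric_type)(**data, **opt)


print(calculate_metric({'img': [1, 2, 3], 'img2': [1, 2, 5]}, {'type': 'calculate_dummy'}))

# A metric whose module was never imported is simply not in the registry:
try:
    calculate_metric({'img': [0], 'img2': [0]}, {'type': 'calculate_lpips'})
except KeyError as err:
    print('KeyError:', err)
```

The same lookup is why the type string you write in the BasicSR config has to match the function name calculate_lpips exactly.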

lpips.py
    from functools import lru_cache

    import lpips
    import numpy as np
    import torch
    from torchvision.transforms.functional import normalize

    from basicsr.metrics.metric_util import reorder_image, to_y_channel
    from basicsr.utils import img2tensor
    from basicsr.utils.registry import METRIC_REGISTRY


    @lru_cache(maxsize=None)
    def _get_lpips_vgg(device):
        """Build the LPIPS (VGG) model once per device and reuse it across calls."""
        return lpips.LPIPS(net='vgg', verbose=False).to(device).eval()


    @METRIC_REGISTRY.register()
    def calculate_lpips(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
        """Calculate LPIPS (lower is better).

        Ref: https://github.com/xinntao/BasicSR/pull/367

        Args:
            img (ndarray): Images with range [0, 255], BGR channel order.
            img2 (ndarray): Images with range [0, 255], BGR channel order.
            crop_border (int): Cropped pixels in each edge of an image. These
                pixels are not involved in the LPIPS calculation.
            input_order (str): Whether the input order is 'HWC' or 'CHW'.
                Default: 'HWC'.
            test_y_channel (bool): Test on Y channel of YCbCr. The Y channel is
                repeated to 3 channels because the VGG backbone expects 3-channel input.
                Default: False.

        Returns:
            float: LPIPS result.
        """
        assert img.shape == img2.shape, f'Image shapes are different: {img.shape}, {img2.shape}.'
        if input_order not in ['HWC', 'CHW']:
            raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
        img = reorder_image(img, input_order=input_order)
        img2 = reorder_image(img2, input_order=input_order)
        img = img.astype(np.float64)
        img2 = img2.astype(np.float64)
        if crop_border != 0:
            img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
            img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
        if test_y_channel:
            # to_y_channel returns (H, W, 1); repeat it so LPIPS receives 3 channels
            img = np.repeat(to_y_channel(img), 3, axis=2)
            img2 = np.repeat(to_y_channel(img2), 3, axis=2)

        # start calculating LPIPS
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        loss_fn_vgg = _get_lpips_vgg(device)  # expects RGB normalized to [-1, 1]
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
        img_gt, img_restored = img2tensor([img2 / 255., img / 255.], bgr2rgb=True, float32=True)
        # norm to [-1, 1]
        normalize(img_gt, mean, std, inplace=True)
        normalize(img_restored, mean, std, inplace=True)

        with torch.no_grad():  # evaluation only, no gradients needed
            lpips_val = loss_fn_vgg(img_restored.unsqueeze(0).to(device), img_gt.unsqueeze(0).to(device))
        return float(lpips_val.cpu().numpy().mean())
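To see what the network actually receives, here is a NumPy-only sketch of the preprocessing calculate_lpips performs before the forward pass: crop the border, swap the BGR order the function assumes (bgr2rgb=True) to RGB, move channels first, and map [0, 255] to [-1, 1], which is what normalize with mean = std = 0.5 does. preprocess_for_lpips is a hypothetical helper for illustration; it skips the Y-channel option and torch entirely.

```python
import numpy as np


def preprocess_for_lpips(img_bgr, crop_border=0):
    """Mimic calculate_lpips's preprocessing with NumPy only (sketch, no torch)."""
    img = img_bgr.astype(np.float32)
    if crop_border != 0:
        img = img[crop_border:-crop_border, crop_border:-crop_border, :]
    img = img[:, :, ::-1]          # BGR -> RGB
    img = img.transpose(2, 0, 1)   # HWC -> CHW
    img = img / 255.0              # [0, 255] -> [0, 1]
    img = (img - 0.5) / 0.5        # mean = std = 0.5: [0, 1] -> [-1, 1]
    return img[np.newaxis, ...]    # add batch dimension -> (1, 3, H, W)


# A pure-blue image in BGR order: channel 0 is 255, the rest are 0
blue = np.zeros((10, 12, 3), dtype=np.uint8)
blue[..., 0] = 255
out = preprocess_for_lpips(blue, crop_border=2)
print(out.shape, out[0, 2].min(), out[0, 0].max())
```

After the BGR to RGB swap, the blue values end up in the last channel at +1.0 while the other two channels sit at -1.0, and cropping 2 pixels per edge shrinks 10x12 to 6x8.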