2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

web-ui用のreal-esrganをReal-ESRGANで動かす

Posted at

Real-ESRGANを動かす際に色々とトラブリましたのでメモも兼ねて残します。

今回使わせて頂いたライブラリはxinntao/Real-ESRGAN
モデルは4x_fatal_Anime_50000_Gをお借りしました
動作環境はGoogle Colaboratory,Python 3.10.12、
アクセレータはT4 gpuです

この記事を見られている方はReal-ESRGANについて既にご存知かと思いますのでこちらの説明は省きます。

まずはpackageのインストールから

install dependency
%pip install basicsr
%pip install git+https://github.com/xinntao/Real-ESRGAN.git
%pip install huggingface_hub

続いてモデルのダウンロード

download model
from huggingface_hub import hf_hub_download

filepath = hf_hub_download("Akumetsu971/SD_Anime_Futuristic_Armor", filename="4x_fatal_Anime_500000_G.pth")

続いてモデルの読み込みです
real-esrganのgitの方ではshellでの実行を進められていましたが、
colab環境ということもあってpythonで動かします
テストスクリプトから必要そうな部分だけ抜き出して使わせて貰います

setup model
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
netscale = 4
__upscaler = RealESRGANer(
        scale=netscale,
        model_path=filepath,
        dni_weight=None,
        model=model,
        tile=0,
        tile_pad=10,
        pre_pad=0,
        half=True,
        gpu_id=None)

まずモデルの読み込みだけでも、と思ったのですがいきなりエラーが

ModuleNotFoundError                       Traceback (most recent call last)
<ipython-input-6-92de257526c4> in <cell line: 1>()
----> 1 from realesrgan import RealESRGANer
      2 from basicsr.archs.rrdbnet_arch import RRDBNet
      3 model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
      4 netscale = 4
      5 __upscampler = RealESRGANer(
------------------ 7 frames ----------------------------
/usr/local/lib/python3.10/dist-packages/basicsr/data/realesrgan_dataset.py in <module>
      9 from torch.utils import data as data
     10 
---> 11 from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
     12 from basicsr.data.transforms import augment
     13 from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor

こちらについて調べたところ、basicrのパッケージがアップデートされておらず、torchvisonとのバージョンの違いで食い違ってしまっていたとこのこと

これに関しては該当箇所を修正すれば良いとのことでした 参考

repare package
!sed -i 's/from torchvision.transforms.functional_tensor import rgb_to_grayscale/from torchvision.transforms.functional import rgb_to_grayscale/' '/usr/local/lib/python3.10/dist-packages/basicsr/data/degradations.py'

パッケージが保存される箇所はvenvなどによっても異なります
ローカルで動かす際には直接ファイルを編集してください

これを実行した上でセッションを再起動し、もう一度モデルを読み込むと、
(実際はセッションを再起動時にfilepathが消えますが、hf_hub_downloadを行なって貰えれば大丈夫です)

KeyError                                  Traceback (most recent call last)
<ipython-input-3-92de257526c4> in <cell line: 5>()
      3 model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
      4 netscale = 4
----> 5 __upscampler = RealESRGANer(
      6         scale=netscale,
      7         model_path=filepath,

/usr/local/lib/python3.10/dist-packages/realesrgan/utils.py in __init__(self, scale, model_path, dni_weight, model, tile, tile_pad, pre_pad, half, device, gpu_id)
     68         else:
     69             keyname = 'params'
---> 70         model.load_state_dict(loadnet[keyname], strict=True)
     71 
     72         model.eval()

KeyError: 'params'

またエラーが出ます

おそらくpthファイルのkeyネームのエラーみたいです。
Real-ESRGANの方で公開されている訓練ずみモデルを開いてみると、

pth file
!wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth
import torch
# Load the .pth file as a tensor
state_dict_o = torch.load("4x_fatal_Anime_500000_G.pth")

for i in state_dict_o.key():
    print(i)
# >> "params_ema"

for i in state_dict_o["params_ema"].keys():
  print(i)

中身はこんな感じになってました。

conv_first.weight
conv_first.bias
body.0.rdb1.conv1.weight
body.0.rdb1.conv1.bias
body.0.rdb1.conv2.weight
body.0.rdb1.conv2.bias
body.0.rdb1.conv3.weight
body.0.rdb1.conv3.bias
body.0.rdb1.conv4.weight
body.0.rdb1.conv4.bias
...

一方ダウンロードしたモデルは、

pth file
import torch
# Load the .pth file as a tensor
state_dict = torch.load(filename)

for i in state_dict.key():
    print(i)

この中身は、

model.0.weight
model.0.bias
model.1.sub.0.RDB1.conv1.0.weight
model.1.sub.0.RDB1.conv1.0.bias
model.1.sub.0.RDB1.conv2.0.weight
model.1.sub.0.RDB1.conv2.0.bias
model.1.sub.0.RDB1.conv3.0.weight
...

実際にこの二つを比べてみると畳み込み層などレイヤーの構造そのものは全く同じみたいです
ただパラメータの名前の付け方が違うらしく、このため読み込めなかったみたいです

そのため、パラメータの名前を編集してあげると、

edit state dict key name
import re
import torch

# Load the .pth file as a tensor
state_dict = torch.load(filepath)
sdict = {}
changename = {
    "model.0": "conv_first",
    "model.1.sub.23": "conv_body",
    "model.1.sub": "body",
    "model.3": "conv_up1",
    "model.6": "conv_up2",
    "model.8": "conv_hr",
    "model.10": "conv_last",
}
opts = "(" + "|".join(changename.keys()) + ")"
regex = re.compile(opts + r"(.*?)(\.0)?(\.weight|\.bias)")
for i in state_dict.keys():
    flag, mid, _, suf = regex.match(i).groups()
    weightname = changename[flag] + mid.lower() + suf
    sdict[weightname] = state_dict[i]
sdict = {"params_ema": sdict}

filepath = "4x_fatal_Anime_500000_G_edit_keyname.pth"
torch.save(sdict, filepath)

こちらで読み込んでみましょう

load model2
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
netscale = 4
__upscaler = RealESRGANer(
        scale=netscale,
        model_path="4x_fatal_Anime_500000_G_edit_keyname.pth",
        dni_weight=None,
        model=model,
        tile=0,
        tile_pad=10,
        pre_pad=0,
        half=True,
        gpu_id=None)

__upscaler
# >> <realesrgan.utils.RealESRGANer at 0x...>

今度はちゃんと読み込めたみたいです

実際にアップスケールしてみましょう
絵に関してはgeminiに作成してもらいました
まず絵を1/4に圧縮します

compress img
# prompt: open dancing_croco.png and compress into one quarter
from PIL import Image

image_path = "/content/dancing_croco.png"
scale_factor = 0.25  # 1/4 of the original size
img = Image.open(image_path)
width, height = img.size
new_width = int(width * scale_factor)
new_height = int(height * scale_factor)
img_resized = img.resize((new_width, new_height), Image.LANCZOS)  # Use a high-quality resampling filter

img_resized

dancing_croco.png

次に逆にupscaleしてみましょう

upscale image
img = pil2cv(img_resized)
        
with torch.no_grad():
  output,_ = __upscaler.enhance(img, outscale=4)

output = cv2pil(output)

output

upscaled.png

模様に巻き込まれてビルが歪んじゃってますね、
背景は除去してからupscaleしたほうが良いかもしれません
また処理自体はかなり高速なので余程バッチを回すとかでなければおそらくhalfは不要です
他にも設定は色々とあるみたいなのでいじってみてください

upscaleする事には見事成功しました!
お疲れ様です

参考
https://github.com/xinntao/Real-ESRGAN

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?