Real-ESRGANを動かす際に色々とトラブリましたのでメモも兼ねて残します。
今回使わせて頂いたライブラリはxinntao/Real-ESRGAN
モデルは4x_fatal_Anime_50000_Gをお借りしました
動作環境はGoogle Colaboratory,Python 3.10.12、
アクセレータはT4 gpuです
この記事を見られている方はReal-ESRGANについて既にご存知かと思いますのでこちらの説明は省きます。
まずはpackageのインストールから
%pip install basicsr
%pip install git+https://github.com/xinntao/Real-ESRGAN.git
%pip install huggingface_hub
続いてモデルのダウンロード
from huggingface_hub import hf_hub_download
filepath = hf_hub_download("Akumetsu971/SD_Anime_Futuristic_Armor", filename="4x_fatal_Anime_500000_G.pth")
続いてモデルの読み込みです
real-esrganのgitの方ではshellでの実行を進められていましたが、
colab環境ということもあってpythonで動かします
テストスクリプトから必要そうな部分だけ抜き出して使わせて貰います
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
netscale = 4
__upscaler = RealESRGANer(
scale=netscale,
model_path=filepath,
dni_weight=None,
model=model,
tile=0,
tile_pad=10,
pre_pad=0,
half=True,
gpu_id=None)
まずモデルの読み込みだけでも、と思ったのですがいきなりエラーが
ModuleNotFoundError Traceback (most recent call last)
<ipython-input-6-92de257526c4> in <cell line: 1>()
----> 1 from realesrgan import RealESRGANer
2 from basicsr.archs.rrdbnet_arch import RRDBNet
3 model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
4 netscale = 4
5 __upscampler = RealESRGANer(
------------------ 7 frames ----------------------------
/usr/local/lib/python3.10/dist-packages/basicsr/data/realesrgan_dataset.py in <module>
9 from torch.utils import data as data
10
---> 11 from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
12 from basicsr.data.transforms import augment
13 from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
こちらについて調べたところ、basicrのパッケージがアップデートされておらず、torchvisonとのバージョンの違いで食い違ってしまっていたとこのこと
これに関しては該当箇所を修正すれば良いとのことでした 参考
!sed -i 's/from torchvision.transforms.functional_tensor import rgb_to_grayscale/from torchvision.transforms.functional import rgb_to_grayscale/' '/usr/local/lib/python3.10/dist-packages/basicsr/data/degradations.py'
パッケージが保存される箇所はvenvなどによっても異なります
ローカルで動かす際には直接ファイルを編集してください
これを実行した上でセッションを再起動し、もう一度モデルを読み込むと、
(実際はセッションを再起動時にfilepathが消えますが、hf_hub_downloadを行なって貰えれば大丈夫です)
KeyError Traceback (most recent call last)
<ipython-input-3-92de257526c4> in <cell line: 5>()
3 model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
4 netscale = 4
----> 5 __upscampler = RealESRGANer(
6 scale=netscale,
7 model_path=filepath,
/usr/local/lib/python3.10/dist-packages/realesrgan/utils.py in __init__(self, scale, model_path, dni_weight, model, tile, tile_pad, pre_pad, half, device, gpu_id)
68 else:
69 keyname = 'params'
---> 70 model.load_state_dict(loadnet[keyname], strict=True)
71
72 model.eval()
KeyError: 'params'
またエラーが出ます
おそらくpthファイルのkeyネームのエラーみたいです。
Real-ESRGANの方で公開されている訓練ずみモデルを開いてみると、
!wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth
import torch
# Load the .pth file as a tensor
state_dict_o = torch.load("4x_fatal_Anime_500000_G.pth")
for i in state_dict_o.key():
print(i)
# >> "params_ema"
for i in state_dict_o["params_ema"].keys():
print(i)
中身はこんな感じになってました。
conv_first.weight
conv_first.bias
body.0.rdb1.conv1.weight
body.0.rdb1.conv1.bias
body.0.rdb1.conv2.weight
body.0.rdb1.conv2.bias
body.0.rdb1.conv3.weight
body.0.rdb1.conv3.bias
body.0.rdb1.conv4.weight
body.0.rdb1.conv4.bias
...
一方ダウンロードしたモデルは、
import torch
# Load the .pth file as a tensor
state_dict = torch.load(filename)
for i in state_dict.key():
print(i)
この中身は、
model.0.weight
model.0.bias
model.1.sub.0.RDB1.conv1.0.weight
model.1.sub.0.RDB1.conv1.0.bias
model.1.sub.0.RDB1.conv2.0.weight
model.1.sub.0.RDB1.conv2.0.bias
model.1.sub.0.RDB1.conv3.0.weight
...
実際にこの二つを比べてみると畳み込み層などレイヤーの構造そのものは全く同じみたいです
ただパラメータの名前の付け方が違うらしく、このため読み込めなかったみたいです
そのため、パラメータの名前を編集してあげると、
import re
import torch
# Load the .pth file as a tensor
state_dict = torch.load(filepath)
sdict = {}
changename = {
"model.0": "conv_first",
"model.1.sub.23": "conv_body",
"model.1.sub": "body",
"model.3": "conv_up1",
"model.6": "conv_up2",
"model.8": "conv_hr",
"model.10": "conv_last",
}
opts = "(" + "|".join(changename.keys()) + ")"
regex = re.compile(opts + r"(.*?)(\.0)?(\.weight|\.bias)")
for i in state_dict.keys():
flag, mid, _, suf = regex.match(i).groups()
weightname = changename[flag] + mid.lower() + suf
sdict[weightname] = state_dict[i]
sdict = {"params_ema": sdict}
filepath = "4x_fatal_Anime_500000_G_edit_keyname.pth"
torch.save(sdict, filepath)
こちらで読み込んでみましょう
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
netscale = 4
__upscaler = RealESRGANer(
scale=netscale,
model_path="4x_fatal_Anime_500000_G_edit_keyname.pth",
dni_weight=None,
model=model,
tile=0,
tile_pad=10,
pre_pad=0,
half=True,
gpu_id=None)
__upscaler
# >> <realesrgan.utils.RealESRGANer at 0x...>
今度はちゃんと読み込めたみたいです
実際にアップスケールしてみましょう
絵に関してはgeminiに作成してもらいました
まず絵を1/4に圧縮します
# prompt: open dancing_croco.png and compress into one quarter
from PIL import Image
image_path = "/content/dancing_croco.png"
scale_factor = 0.25 # 1/4 of the original size
img = Image.open(image_path)
width, height = img.size
new_width = int(width * scale_factor)
new_height = int(height * scale_factor)
img_resized = img.resize((new_width, new_height), Image.LANCZOS) # Use a high-quality resampling filter
img_resized
次に逆にupscaleしてみましょう
img = pil2cv(img_resized)
with torch.no_grad():
output,_ = __upscaler.enhance(img, outscale=4)
output = cv2pil(output)
output
模様に巻き込まれてビルが歪んじゃってますね、
背景は除去してからupscaleしたほうが良いかもしれません
また処理自体はかなり高速なので余程バッチを回すとかでなければおそらくhalfは不要です
他にも設定は色々とあるみたいなのでいじってみてください
upscaleする事には見事成功しました!
お疲れ様です
参考
https://github.com/xinntao/Real-ESRGAN