LoginSignup
0
0

spandrelから超解像・明るさ補正などのモデルをCoreMLに変換する

Last updated at Posted at 2024-06-16

いろんなモデルを同じ手順で変換できる

機械学習モデルをiOSで使うためにCoreMLに変換するとなると、これまではモデルのリポジトリからモデルアーキテクチャをイニシャライズして重みを読み込ませないといけなかったが、

spandrelはpytorchのチェックポイントファイル(pth)から、モデルの構造まで復元して使えるライブラリで、超解像やインペインティングなど様々なモデルをサポートしている。
必要なのは重み(チェックポイントファイル)のみです。

これらをCoreMLに変換してiOSやmacOSオンデバイスで使えたら便利、っていうことで変換していく。

今回は低照度画像強化のRetinexFormerでやってみましょう。

darkroom.jpgdarkroome.jpg

例:RetinexFormerの変換手順

インストール

pipでインストール。

pip install spandrel spandrel_extra_arches
# git clone https://github.com/chaiNNer-org/spandrel.git
# import shutil
# shutil.copytree("spandrel/libs/spandrel/spandrel/","/usr/local/lib/python3.10/dist-packages/spandrel")

基本はpipのみでいけるが、なぜかpipでインストールすると、モデルアーキテクチャが入ってないやつがあったりするので、もし後のモデル初期化でうまくいかなかったらリポジトリをクローンしてパッケージに置き換えよう。

チェックポイントファイルのダウンロード

チェックポイントファイルはspandrelのreadmeのリンクなどからダウンロードできる。
https://github.com/chaiNNer-org/spandrel

モデルの初期化

ModelLoaderにpthファイルを読み込ませてモデルを初期化。
普通ならそれだけでいいが、商用利用ライセンスのないモデルはspandrel_extra_archesからモデルを登録しないといけない。codeformerもそうです。

from spandrel import MAIN_REGISTRY, ModelLoader, ImageModelDescriptor
from spandrel_extra_arches import EXTRA_REGISTRY

# add extra architectures before `ModelLoader` is used
MAIN_REGISTRY.add(*EXTRA_REGISTRY)

# load a model from disk
model = ModelLoader().load_from_file(r"LOL_v2_real.pth")

# make sure it's an image to image model
assert isinstance(model, ImageModelDescriptor)

# send it to the GPU and put it in inference mode
model.cuda().eval()

一度torchでモデルの動作確認をしましょう。
推論方法は以下参照。

ラップモデルを作る

モデルに後処理をつけたnn.moduleを作ります。

import torch

torch_model = model.model.eval().cpu()

class CoreMLModel(torch.nn.Module):
    def __init__(self, m):
        super(CoreMLModel, self).__init__()
        self.m = m

    def forward(self, image):
        pred = self.m(image)
        output = torch.clamp(pred * 255, min=0, max=255)
        return output

coremlmodel = CoreMLModel(torch_model)

ラップモデルを変換します。

import coremltools as ct
import torch

ex = torch.randn(1,3,512,512).cpu()
ts = torch.jit.trace(coremlmodel, ex)
mlmodel = ct.convert(ts, inputs=[ct.ImageType(shape=ex.shape,scale=1/255)],outputs=[ct.ImageType(name="output")])
mlmodel.save("retiinexformerNTIRE.mlpackage")

🐣


フリーランスエンジニアです。
AIについて色々記事を書いていますのでよかったらプロフィールを見てみてください。

もし以下のようなご要望をお持ちでしたらお気軽にご相談ください。
AIサービスを開発したい、ビジネスにAIを組み込んで効率化したい、AIを使ったスマホアプリを開発したい、
ARを使ったアプリケーションを作りたい、スマホアプリを作りたいけどどこに相談したらいいかわからない…

いずれも中間コストを省いたリーズナブルな価格でお請けできます。

お仕事のご相談はこちらまで
rockyshikoku@gmail.com

機械学習やAR技術を使ったアプリケーションを作っています。
機械学習/AR関連の情報を発信しています。

X
Medium
GitHub

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0