LoginSignup
39
24

More than 3 years have passed since last update.

TorchScriptを使用してPyTorchのモデルを保存する

Posted at

はじめに

みなさん、PyTorchで学習したモデルを保存するときには以下のようなコードを書いているのではないでしょうか?

usual_save_load.py
torch.save(model.state_dict(), PATH)
model.load_state_dict(torch.load(PATH))

しかし、この方法では学習したモデルで推論を行いたい場合に、loadするパラメータを学習した時のモデルと同じクラスのオブジェクトも作ってやる必要があります。
これは、例えばkaggleのnotebookで推論を実行して提出しなければならないコンペティションなどでは管理の手間が増えて少し面倒です。

ここで提案したいのがTorchScriptです!

TorchScriptとは

ここでTorchScriptについて詳しく書きたいところなのですが、これについて詳しく書き出すとキリがないのと、この記事で触れたい点から少しズレていくので詳細は説明しません。詳しく知りたい方は以下を参照するといいでしょう。
https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html
https://pytorch.org/docs/stable/jit.html
簡単にだけ説明するとPyTorchがTensorflowと比べて弱いとされていたデプロイ周りの強化をするにあたって追加された機能で、PyTorchコードを独自の中間表現に変換することでpythonの学習コードから離れてpythonに依存しないプロセスで使用しやすくする仕組みです。(簡単にまとめすぎて誤解があるかもしれません。叩かないでください...)

TorchScriptを使ってみる

それでは実際にTorchScriptを触っていきたいと思います。普通のモデルとの実行時間の比較、およびモデルのsave, load方法を説明していきたいと思います。

推論時間比較

まずはtorchvisionのモデルを使用してtorchscriptを試してみたいと思います。
jupyter notebookの%timeitを使用して簡単に推論時間の違いも確認してみたいと思います。

time_compare.py
import torch
import torchvision.models as models

input = torch.rand(1, 3, 224, 224).cuda()
model = models.resnet34().eval().cuda()
script_model = torch.jit.script(model).eval().cuda()

print('Normal model inferece')
%timeit model(input)
print('TorchScript model inferece')
%timeit script_model(input)

結果

Normal model inferece
4.07 ms ± 19.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
TorchScript model inferece
2.48 ms ± 5.72 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

TorchScriptにしたほうが推論が早くなっていることがわかりますね。

TorchScriptのsave

次に本題のsaveとloadですが以下のように書けます、簡単ですね。モデルをloadするためにあらかじめモデルを定義しなくてもいいことがわかると思います。kaggle notebookで推論するときにもこのモデルだけアップロードしておけばいいので便利です。

torchscript_save_load.py
PATH = "resnet34.pt"
script_model.save(PATH)
load_model = torch.jit.load(PATH).cuda()

ちなみに推論時間もチェックしてみます。

torchscript_inference.py
load_model = torch.jit.load("resnet34.pt").eval().cuda()
%timeit load_model(input)

結果

2.47 ms ± 17.5 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

saveする前とほとんど同じですね。ちなみに推論結果も完全一致することも確認しました。

複雑なモデルでも変換できるの?

こういうモデルを別形式に変換するのを試したことがある人なら誰でも気になるんじゃないでしょうか?対応していないレイヤなどが結構あって使い物にならないみたいなことも多いです。実際torchscriptが出た直後はよく変換が失敗した気がします。
ということでいかのリポジトリで使えるモデルを色々変換してみたいと思います。
timm(https://github.com/rwightman/pytorch-image-models)
segmentation models pytorch(https://github.com/qubvel/segmentation_models.pytorch)
まずはtimmから

timm_convert_test.py
import timm

model_names = ['gluon_seresnext50_32x4d', 'efficientnet_b3', 'gluon_xception65', 
'gluon_resnext50_32x4d', 'tf_mixnet_s', 'tf_mobilenetv3_small_075', 'tf_efficientnet_lite3',
'mobilenetv2_100', 'mnasnet_b1']
for name in model_names:
    try:
        model = timm.create_model(name, pretrained=False)
        torch.jit.script(model).save("torchscript_model.pt")
        print('{} is clear'.format(name))
    except:
        print('{} is failed'.format(name))

結果

gluon_seresnext50_32x4d is clear
efficientnet_b3 is failed
gluon_xception65 is clear
gluon_resnext50_32x4d is clear
tf_mixnet_s is failed
tf_mobilenetv3_small_075 is failed
tf_efficientnet_lite3 is clear
mobilenetv2_100 is clear
mnasnet_b1 is clear

いくつか失敗してますね。エラーを見てみると以下のように出ていたのでSwishがうまく変換できないのだと思います。

Can't redefine method: forward on class: __torch__.timm.models.layers.activations_me.SwishMe

では次にsegmentation models pytorchを試してみます。

smp_convert_test.py
import segmentation_models_pytorch as smp

model_names = ['resnet34', 'dpn68', 'vgg13', 'densenet121', 'timm-efficientnet-b0', ]
for name in model_names:
    try:
        model = smp.Unet('resnet34', encoder_weights=None)
        torch.jit.script(model).save("torchscript_model.pt")
        print('{} is clear'.format(name))
    except:
        print('{} is failed'.format(name))

結果

resnet34 is failed
dpn68 is failed
vgg13 is failed
densenet121 is failed
timm-efficientnet-b0 is failed

あ、ダメそう...エラーを見るとこんな感じです。

Can't redefine method: forward on class: __torch__.segmentation_models_pytorch.encoders.resnet.ResNetEncoder

変換できないときのtrace

実はTorchScript化するにはscriptとtraceの二通りの方法があります。
traceではモデルのforwardの処理をトレースすることで変換します。要はjitコンパイルですね。サンプルの入力テンソルを用意してやる必要があります。入力のサイズが変わるような場合には不向きです。
ではこれを使って先ほどの変換をもう一度試してみます。

trace_convert.py
import timm
import segmentation_models_pytorch as smp

example_input = torch.rand(1, 3, 224, 224)
model_names = ['gluon_seresnext50_32x4d', 'efficientnet_b3', 'gluon_xception65', 
'gluon_resnext50_32x4d', 'tf_mixnet_s', 'tf_mobilenetv3_small_075', 'tf_efficientnet_lite3',
'mobilenetv2_100', 'mnasnet_b1']

for name in model_names:
    model = timm.create_model(name, pretrained=False)
    try:
        torch.jit.trace(model, example_inputs=example_input).save("torchscript_model.pt")
        print('{} is clear'.format(name))
    except:
        print('{} is failed'.format(name))
        print(e)

model_names = ['resnet34', 'dpn68', 'vgg13', 'densenet121', 'timm-efficientnet-b0', ]
for name in model_names:
    try:
        model = smp.Unet('resnet34', encoder_weights=None)
        torch.jit.trace(model, example_inputs=example_input).save("torchscript_model.pt")
        print('{} is clear'.format(name))
    except:
        print('{} is failed'.format(name))

結果

gluon_seresnext50_32x4d is clear
efficientnet_b3 is failed
gluon_xception65 is clear
gluon_resnext50_32x4d is clear
tf_mixnet_s is failed
tf_mobilenetv3_small_075 is failed
tf_efficientnet_lite3 is failed
mobilenetv2_100 is clear
mnasnet_b1 is clear
resnet34 is clear
dpn68 is clear
vgg13 is clear
densenet121 is clear
timm-efficientnet-b0 is clear

先ほどは変換できなかったUnetも成功していることが分かりますね。まだ失敗しているのもありますが、ここを掘り下げすぎるとキリがなさそうなのでやめておきます。

まとめ

torchscriptを使用することでモデルのデプロイ&推論が簡単になることが分かりました。また、パフォーマンスが改善する場合もあることが分かりました。時々変換できない場合もありますが、そういう時はエラーを読みながら対応するか、今まで通りの方法で頑張りましょう。でもかなり実用的なレベルで変換ができるようになってきているのを感じます。
今回紹介した機能はtorchscriptのほんの一部の機能の使い方であり、torchscriptは他にも様々な可能性を秘めた代物です。また、自分もtorchscriptについての理解は十分とは言えないので間違いなど見つけた方は気軽に指摘ください!

39
24
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
39
24