Help us understand the problem. What is going on with this article?

PyTorch→ONNX→NNVM

More than 1 year has passed since last update.

はじめに

こちらの記事で紹介したNNVMですが、記事内であげていた

  • OpenCLビルドが通らない
  • PyTorchからのONNX exportが通らない

という問題は開発が進み解消されましたので、その分を書きます。

今回はPyTorchの学習済モデルをONNX経由でLLVM/OpenCLビルドして、PyTorchと実行速度を比較してみます。OpenCLのターゲットは MacBook Pro のIntel HD Graphics4000です。

学習済モデルの確認

torch-visionの学習済モデルを使います。VGG,Alexnet,Resnet,Squeezenet,Inception,Densenetが用意されています。下記のように簡単に学習済モデルが得られます。

torch_model = models.vgg16(pretrained=True)

まず PyTorch で推論して動作を確認します。OpenCVで読み込んだ画像をfloatに変換したりしてモデル入力に合わせるサポート関数を作ります。

def img2dat(img):
    inshape=(1,3,224,224)
    rgb = cv2.cvtColor(img,cv2.COLOR_BGR2RGB).astype(np.float32)/255.0
    rgb = normalize(rgb)
    rgb = cv2.resize(rgb,(224,224))
    dat = rgb.transpose(2,0,1)
    dat = dat.reshape(inshape)
    return dat

ドキュメントによると学習時の入力は正規化が前提だそうなので、正規化関数を書きます。PyTorchにもユーティリティ関数はありますが、のちのちNNVMにするときにPyTorchフリーにしたいので。

def normalize(img):
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    for c in range(3):
        img[:,:,c] =  (img[:,:,c]-mean[c])/std[c]
    return img

以上の準備ができたら、torch_from_numpy()でテンソルに変換し、Variableにしてからモデルに投入します。

img = cv2.imread("image.jpg")
dat=torch.from_numpy(img2dat(img))
x = Variable(dat)
out=torch_model(x)

クラス番号の文字列への変換はimagenet1000_clsid_to_human.txtを使います。下記でclass_dictというディクショナリを作ります。

classfile="imagenet1000_clsid_to_human.txt"
f=open(classfile).read()
exec("class_dict="+f)

出力はImagenet1000クラスの確率ベクタをバッチ方向に並べたテンソルです。dataでテンソルにnp.array()としてアクセスできます。推論時はバッチ数は1個なので、out.data[0]で推論結果の確率ベクタが得られます。argmaxを取れば最も可能性の高いクラス番号が得られます。これを先程のclass_dictを使って文字列に変換します。・・・とまあ、文章だと長いですがコードだと下記一行です。

result=class_dict[np.argmax(out.data[0])]

ONNXへの変換

下記要領でExportできます。ダミーの入力を入れる必要があります。

torch_model.train(False)
x = Variable(torch.randn(1, 3, 224, 224), requires_grad=True)
torch_out = torch.onnx._export(torch_model,x,target_path,export_params=True)

ONNXからビルド

下記要領でできます。shape_dictで入力テンソルのポート名とサイズを規定します。

onnx_graph = onnx.load(onnx_file_name)
sym, params = nnvm.frontend.from_onnx(onnx_graph)

target="llvm"
shape_dict = {'input_0': (3,224,224)}

deploy_graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, params=params)

lib.export_library("deploy.dylib")
with open("deploy.json", "w") as fo:
    fo.write(deploy_graph.json())
with open("deploy.params", "wb") as fo:
    fo.write(nnvm.compiler.save_param_dict(params))

ビルド結果は下記3ファイルに出力されます。

  • dylib: Dynamic Linkライブラリ
  • json: グラフ構造
  • params: パラメータ

torch-vision→ONNX→NNVMという経路での
ビルドについてはいろいろ試したところ、現時点では

  • AlexNet
  • VGG

しか通りませんでした。その他のモデルは下記のようにそれぞれ問題がありました。

  • ResNet: そもそもオリジナルモデルで認識がうまくいかない。
  • SqueezeNet max_pool2dがないといわれONNX変換がうまくいかない。
  • Inception_v3: Output size is too small とでてONNX変換がうまくいかない。
  • densenet: CUDA必須のためCPUではONNX変換がうまくいかない。

ビルド済モデルをつかった推論

下記のようなコードで推論が行えます。

loaded_lib = tvm.module.load("deploy.dylib")
loaded_json = open("deploy.json").read()
loaded_params = bytearray(open("deploy.params", "rb").read())

dat=img2dat(img)

module.set_input('input_0', tvm.nd.array(dat))
module.run()
out=module.get_output(0, out=tvm.nd.empty(outshape))
out=out.asnumpy()

作業まとめ

GitHubに以上のコードをまとめてあります。

下記コマンドで、それぞれ、vgg16のonnx書き出し、llvmでのビルド、ビルド結果を使った推論までが実行できます。

python export.py vgg16
python compile.py llvm vgg16
python replay.py orange.jpg llvm vgg16

速度比較

VGG16で、素のPyTorch、LLVM, OpenCLを比較しました。20回ずつの平均です。

フレームワーク 実行時間[sec]
PyTorch 0.71
LLVM/NNVM 17.41
OpenCL/NNVM 2.92

LLVM,OpenCLの比較はかなり差が出ているものの、素のPyTorchが一番速いという結果です。こちらに記載があるように、現時点では演算スケジューリングの最適化について、ARMを優先して対応しているようで、x86はまだ行っていないとのことでしたのでまだ実力は出しきれてないと思います。それよりはこういう枠組が作られたことのほうが偉大かと思います。

RaspberryPiへのデプロイについて

さて、そういう流れからすると、当然RaspberryPiで動かしたいところですが、残念ながら本記事には間に合いませんでした。こちらのドキュメントにラズパイへのデプロイ方法が書かれています。クロスビルドして、RPC経由でライブラリやパラメータを転送し、実行するようです。

ただ、手元の RaspberryPi Zero W で試していますが現状まだうまく行ってません。

target="llvm -target=armv6-none-linux-gnueabihf -mcpu=arm1176jzf-s -mattr=+neon"

あたりで試しているところですが、

check failed: code == RPCCode::kReturn code=4

等のエラーで止まります。今回はあきらめました。。。

Why do not you register as a user and use Qiita more conveniently?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away