LoginSignup
2
0

More than 1 year has passed since last update.

Androidでの推論実行方法 Java編 【PyTorch】

Last updated at Posted at 2022-09-24

概要

既存のJava製AndroidネイティブアプリにPyTorchベースの推論処理を実装する必要があったので、備忘として調べた結果をまとめます。

  1. PyTorch Java APIで推論処理を記述する(本記事)
  2. Python(Chaquopy)で記述した推論処理を呼び出す

環境

ハードウェア: SHARP SH-T01
プロセッサ: Snapdragon 665 (2GHz & 1.8GHz - Octa)
メモリ: 4GB
OS: Android 12

推論実行方法

PyTorch Java APIを使うと、Javaクラスにそのまま推論を記述することができます。

https://github.com/pytorch/android-demo-app/blob/master/HelloWorldApp/app/src/main/java/org/pytorch/helloworld/MainActivity.java
// input image読み込み
Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
        TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB, MemoryFormat.CHANNELS_LAST);

// モデル読み込み・推論
Module module = LiteModuleLoader.load(assetFilePath(this, "model.pt"));
final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();

モデルの利用にあたり、TorchScriptへの変換が必要です。
TorchScriptについては公式の説明がわかりやすいので割愛しますが、PyTorchで記述したモデルを、C++など他の言語からでも呼び出すことのできる中間表現(Intermediate Representation)と理解しています。

Python版PyTorchで学習したモデルをTorchScriptに変換するには、torch.jit.traceを使います。(推論処理中に含まれるcontrol flowも変換後のモデルに反映するtorch.jit.scriptも存在しますが、ここでは割愛します)

適当なinputを与えて推論させることで元の処理の動作をトレースし、TorchScriptに変換します。

https://github.com/pytorch/android-demo-app/blob/master/HelloWorldApp/trace_model.py
model = torchvision.models.mobilenet_v3_small(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)

また、traceしたモデルはoptimize_for_mobileを使って最適化することができます。

Torch mobile_optimizer package does several optimizations with the scripted model, which will help to conv2d and linear operations. It pre-packs model weights in an optimized format and fuses ops above with relu if it is the next operation.

高速化については、公式にPYTORCH MOBILE PERFORMANCE RECIPESというドキュメントが存在し、チューニング方法やベンチマーキングツールの使い方が紹介されています。

https://github.com/pytorch/android-demo-app/blob/master/HelloWorldApp/trace_model.py
optimized_traced_model = optimize_for_mobile(traced_script_module)

備考

AndroidでPyTorchを使うにあたり、公式のデモ集がとても充実しているので、試してみる上での障壁はかなり低いと思います。画像認識、NLP、C++スクリプトの組み込みなど、様々なデモが用意されています。

上記でコードを一部紹介しましたが、PyTorch Java APIを使った推論の最も簡単な例としては、HelloWorldAppを参照するのが良さそうです。

実運用まで考えると、現時点では安定性からPyTorch Java API一択かなという印象で、趣味等で軽く試す場合は別記事で触れたChaquopyが候補となるように感じました。

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0