概要
既存のJava製AndroidネイティブアプリにPyTorchベースの推論処理を実装する必要があったので、備忘として調べた結果をまとめます。
- PyTorch Java APIで推論処理を記述する(本記事)
- Python(Chaquopy)で記述した推論処理を呼び出す
環境
ハードウェア: SHARP SH-T01
プロセッサ: Snapdragon 665 (2GHz & 1.8GHz - Octa)
メモリ: 4GB
OS: Android 12
推論実行方法
PyTorch Java APIを使うと、Javaクラスにそのまま推論を記述することができます。
// input image読み込み
Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB, MemoryFormat.CHANNELS_LAST);
// モデル読み込み・推論
Module module = LiteModuleLoader.load(assetFilePath(this, "model.pt"));
final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
モデルの利用にあたり、TorchScriptへの変換が必要です。
TorchScriptについては公式の説明がわかりやすいので割愛しますが、PyTorchで記述したモデルを、C++など他の言語からでも呼び出すことのできる中間表現(Intermediate Representation)と理解しています。
Python版PyTorchで学習したモデルをTorchScriptに変換するには、torch.jit.trace
を使います。(推論処理中に含まれるcontrol flowも変換後のモデルに反映するtorch.jit.script
も存在しますが、ここでは割愛します)
適当なinputを与えて推論させることで元の処理の動作をトレースし、TorchScriptに変換します。
model = torchvision.models.mobilenet_v3_small(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
また、traceしたモデルはoptimize_for_mobile
を使って最適化することができます。
高速化については、公式にPYTORCH MOBILE PERFORMANCE RECIPESというドキュメントが存在し、チューニング方法やベンチマーキングツールの使い方が紹介されています。
optimized_traced_model = optimize_for_mobile(traced_script_module)
備考
AndroidでPyTorchを使うにあたり、公式のデモ集がとても充実しているので、試してみる上での障壁はかなり低いと思います。画像認識、NLP、C++スクリプトの組み込みなど、様々なデモが用意されています。
上記でコードを一部紹介しましたが、PyTorch Java APIを使った推論の最も簡単な例としては、HelloWorldAppを参照するのが良さそうです。
実運用まで考えると、現時点では安定性からPyTorch Java API一択かなという印象で、趣味等で軽く試す場合は別記事で触れたChaquopyが候補となるように感じました。