LoginSignup
2
7

More than 3 years have passed since last update.

BERT など Transformer 系 model を TorchScript を経由して PyTorch Mobile(Android) で動作させ NLP したい

Last updated at Posted at 2019-10-14

背景

  • BERT など Transformer 系の NLP を Android で動かしたい.
  • PyTorch 1.3 により, PyTorch Mobile が対応されてきており, PyTorch のモデルをモバイルで動かせる機運がたかまる
    • TFLite も頑張ればできそうではあるが..(or libtensorflow を Android で動かすとか)
  • huggingface python-transformers https://github.com/huggingface/transformers が TorchScript 対応しているのでこれをベースにする

ちなみに iOS(Core-ML)はすでにあります.

状況

JIT Tracing により .pt を作れるところまで確認

変換の流れ

v1.3 の PyTorch Mobile から, モバイルで動かすには onnx 経由で caffe2 で動かすではなく, TorchScript でモバイル実行(C++ 実行と近い?)が推奨になっているようです

pytorch コード(python) -> TorchScript -> Mobile 実行(libtorch? caffe2?)

TorchScript(JIT) とは?

TensorFlow でいう tflite や, frozen model データ(.pb)に相当でしょうか.

フロントエンド言語としては, Python 言語の subset となります.
ここからさらに, IR(中間言語)として LLVM IR っぽい感じの言語に落としこまれます.

PyTorch ソースコードをちらっと見ると, ゼロから python ライクな言語(と LLVM IR っぽい IR 言語)を実装していて Facebook の開発力を感じますね.

tracing と scripting の二種類があります(内部的には同じ?)

制御構造などを持たないものは tracing で対応, 制御構造を持つもの(e.g. RNN 系)は scripting で対応のようです.

なんとなく Mobile で動かすには tracing だけで変換できるようにするのが推奨そうな気がしますね.

アノテーション

TorchScript に変換するために, 以前は python コードで, class は torch.jit.ScriptModule を継承したり, @ デコレータ指定が必要でしたが, v1.2 からほぼ既存の python コードそのままを扱えるようになりました.

型情報が必要な場合は, mypy 形式でコメントに型情報を記載するか, python 3.6 からの typing モジュールを使って型付けする必要があります.

HuggingFace pytorch-transformers のコードを見ると, JIT 用のアノテーションはありませんし, typing も使っていません. TorchScript 化を考慮してうまく python で記述していますね.

HuggingFace pytorch-transformers を TorchScript 変換する

あまり情報はありませんが, README.md にちらっと記述があります.

    # Models are compatible with Torchscript
    model = model_class.from_pretrained(pretrained_weights, torchscript=True)
    traced_model = torch.jit.trace(model, (input_ids,))

固定長の input を用意して, torch.jit.trace 呼び出しています.

実際に BertModel('bert-base-uncased') は変換できるのを確認しました. およそ 400 MB 程度のファイルサイズとなりました.

libtorch(C++)でモデルを読み込む

PC(Linux) 環境ですが,

torch.jit.load

で無事に trace した .pt ファイルをロードできるのを確認しました. ロードに数秒かかります.

TorchScript 拡張子

セーブ時はファイル拡張子見ていないようなので, 実際のところ .pt, pth のどれでも同じになります.
実際のところは NPZ(numpy zipfile)と同じように, 複数のファイルを非圧縮でまとめた zip のようです.

TODO

  • libtorch(C++) で .pt を読み込みで動作するか確認する
  • Android で動かす
    • asset フォルダからの load API を調べる
  • scripting にして, 可変長入力に対応できるか試す
2
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
7