背景
pytorch モデル(特に音声処理系)を trace(TensorFlow でいう freezed model, tflite model)したい.
pytorch 自体がある程度は型推論してはくれますが, 限界があります.
特に音声処理系だと, neural 部分以外にもいろいろとコードがあります. また, 動的配列とか再帰などがあるため, 適切に型付けが必要になります.
環境
pytorch v1.4.0 を想定します(2020/04/04 時点での最新 stable).
JIT モデル
- TorchScript : Python のサブセットでスクリプト的. 型は必要
- 入力のテンソルが動的に変わる(特に音声系など時系列データやテキストデータなど)ときに使う
- 実行時にコンパイルが必要になりのかしら?(その場合, 起動にいくらか時間かかるかも)
- Traced model : TensorFlow で言う freezed model, tflite model に近い?
- 入力テンソルのサイズが固定のときに使える.
- 推論が早くなる... かも?
- 構文には制限があるかも
traced model が理想で, 次点で TorchScript でしょうか.
一応両方組み合わせることもできるようです.
環境
- Python 3.6 or later(Python で型付けする
typing
module と, 型アノテーション構文が使える)
型を調べる
現在の Python スクリプトの型をステップ実行して調べる手があります.
Python デバッガや ipython を使うといいのでしょうか...
Juyter lab とかでも表示してくれるのかな?
筆者は基本 vim + コマンドライン実行しか知らないなので, それぞれ型が欲しいところに print(type(x)) などを挿入して調べています...
型付けする
python2 などや 3.5 用に, torch.jit.annotate
を使ったり, コメントに型を記述する手もありますが, typing
モジュールを使い, python 構文内に型アノテーションするのが推奨です.
サンプル
def forward(self, x: List[torch.Tensor]) -> Tuple[torch.Tensor]:
my_list: List[Tuple[int, float]] = []
という感じでいけます.
Optional
C++ でいう std::optional のように, None or なにか型を持つ, は Optional[T]
でいけます.
Optional[int]
@torch.jit.export
通常は forward()
メソッドと, forward から呼ばれる関数しか JIT compile しませんが, @torch.jit.export
デコレータを使うことで, 明示的にメソッドをエクスポート(JIT compile させる)できます
(forward
は暗に @torch.jit.export
がデコレートされている)
nn.ModuleList
nn.ModuleList(配列)を
...
self.mods = nn.Modulelist([...])
for i in range(10):
self.mods[i](x)
のように配列インデックスでアクセスするのは現状できません.
[jit] Can't index nn.ModuleList in script function #16123
https://github.com/pytorch/pytorch/issues/16123
とりあえずは __constants__
と, for mod in modules
のような形で iterate すれば対応できると思われますが, 複数の nn.ModuleList を使う場合は専用の class を再定義することになるかと思います.
しかし, この場合 network op の定義が変わり(state_dict の名前が変わる), pretrained model の weight をうまく対応しなおしが必要になります.
また, v1.5.0(v1.6.0?) では self.mods[0]
など, 定数での配列インデックスは対応し初めていますが,
[JIT] Add modulelist indexing for integer literal #29236
https://github.com/pytorch/pytorch/pull/29236
TorchScript の evaluate(libtorch 側)でエラーになりました.
(getattr(xxx, 10)
のような式になっていて, これが runtime 時にパースできない)
もう少し成熟をまつ必要があります.
さらに, nn.ModuleList を iterate する場合は, reversed
による逆順イテレーションはサポートしていません.
print, assert
TorchScript では, print
, assert
が TorchScript でも動きます(trace ではダメかも).
デバッグ用にメッセージを出すなどに使えます.
JIT でスクリプトが実行されているか
scripting で実行されている場合に, 処理を一部省きたかったり, None が想定されるが non None のときは型は任意になるので Optional[T]
で型付けができないなどで, 処理を分けたいケースがあります.
torch.jit.is_scripting()
は, ランタイム時にスクリプト実行(libtorch で実行)されているかどうかを判定のため, トレース時(コンパイル時)かどうかの判定にはつかえません.
なにかデコレータなどあるといいのですが, 現状はなさそうです.
したがって, 関数単位で python 用か TorchScript 用かを切り替えることはできなさそうです.
torchscript ドキュメントにあるように,
@torch.jit.ignore
def forward_pytorch():
...
def forward_for_torchscript():
...
def forward():
if torch.jit.is_scripting():
forward_for_torchscript()
else
foward_pytorch()
と, ある程度コードの編集が必要になります. ただ, 式(ステートメント)自体はトレースの対象となるため, numpy() など使っているコードがあるとコンパイルできずにエラーになります.
上記のように関数化して, @torch.jit.ignore
に pytorch(+ numpy)で実行するコードを移行する必要があります(@torch.jit.unused
だとコンパイル対象になるため)
nn.RNN での _flatten_parameters()
内部で使っている GeneratorExp が TorchScript に対応していません.
GPU 用にメモリレイアウトを調整する用なので, 無視(コード削除)しても大丈夫でしょう.
ほか
-
@torch.jit.unused
decorator -
@torch.jit.ignore
decorator
unused や ignore は, forward
が定義されているが, これは学習用のためなどで, TorchScript では無視したいときに使える.
unused と ignore の違いは, unused ではメソッドを呼ぶと例外を出すが, ignore はメソッドを呼んでもなにもしない. 基本は unused を使ったほうがよさそう.
F.pad(x, [0, 0, maxlen - m.size(2), 0])
^^^^^^^^^^^^^^^^^^^^^
みたいなのが, List[int]
と型推論されませんでした. (m は torch.Tensor). 明示的に int 型の変数つくるなどして解決.
.numpy()
.numpy()
は使えない模様. e.g. x.cpu().data.numpy()
. C++ 側では aten がうまく処理してくれるから, .numpy()
は使わなくても大丈夫かも...?
また, トレースするコード内で numpy 関数を使っていないことがのぞましい.
- 定数
T.B.W.
エントリとなる forward
の返り値の型
エントリとなる Model の forward
が返す型を明示的に指定しておくとよさそうです.
C++ 側で実行するときにどの型かわかりやすくするためです.
(型が合わないと実行時に assertion が出る)
ひとつの Tensor だけ返すときは torch::Tensor
として扱えます.
複数のテンソルを返す場合は Tuple になりますので,
model.forward(inputs).toTuple()
とします.
TODO
def myfun(x, activation = None):
if activation:
x = activation(x)
return x
myfun(activation=torch.relu)
のように, 任意の関数 or None みたいなのを扱いたいときどうすればいいか?
class Mod:
def forward(self, x, alpha=1.0):
...
class Model:
def __init__(self):
self.mod = Mod()
def forward(self, x):
self.mod.forward(x)
...
のように, 内部で呼んでいる class の forward に optional 引数がある.
実際のところ使っていなければ削除すればよい?