1
1

More than 1 year has passed since last update.

TrOCRのモデルをONNX化するのにハマった点

Last updated at Posted at 2022-10-27

decoderとencoder別々でONNX化しなければならなかった。

TrOCRのモデルをONNX formatに変換するときにハマったのでメモ。
Pytorchのチュートリアル「EXPORTING A MODEL FROM PYTORCH TO ONNX AND RUNNING IT USING ONNX RUNTIME」 をみるとtorch.onnx.exportに作ったモデルを通せば良いんだなーとだけ思っていたが、TrOCRのようなtransfomerが入ってるモデルだと、エンコーダーとデコーダーでそれぞれtorch.onnx.exportする必要がある。例えばこんな感じ↓ 

torch.onnx.export(trocr_encoder,               
                  x,   
                  "trocr_encoder.onnx",   
                  export_params=True,
                  opset_version=10, 
                  do_constant_folding=True, 
                  input_names = ['input'],
                  output_names = ['output'],
                  dynamic_axes={'input' : {0 : 'batch_size'},
                                'output' : {0 : 'batch_size'}})


with torch.no_grad():
    torch.onnx.export(
        trocr_decoder,
        args=(input_ids, attention_mask, encoder_hidden_states),
        f='trocr_decoder.onnx',
        opset_version=15,
        input_names=[
            "input_ids",
            "attention_mask",
            "encoder_hidden_states",
        ],
        output_names=["logits"],
        dynamic_axes={
            "input_ids": {0:'batch', 1:'sequence_length'},
            "attention_mask": {0:'batch', 1:'sequence_length'},
            "encoder_hidden_states": {0:'batch'},
            "logits": {0:'batch'},
        },
        do_constant_folding=True,
    )

これだけで数時間を無駄にした。

参考

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1