ONNXとは
ONNX(Open Neural Network Exchange)とはさまざまなAIプラットフォームで開発されるモデルを統一的に表現できることを目指したフォーマット。
引用: https://docs.ultralytics.com/ja/integrations/onnx/
AIアクセラレータとONNX
HaloやBlaizeなどのいわゆるAIアクセラレータは専用ハードウェアのためにコンパイルされた形式のファイルを実行するケースが多い。
その際コンパイル元としてONNXフォーマットが使用されることが多い。
引用: https://docs.ultralytics.com/ja/integrations/onnx/
ONNX Runtimeについて
ONNXフォーマットはプラットフォームやフレームワーク間の中間形式として使用されることが多いが、ONNX Runtimeを使用することでONNXフォーマットを用いて直接推論や学習を行うことも可能である。
PythonでのONNXファイル操作例
ここではpythonのonnxモジュールを用いてONNXフォーマットを操作する様子を例示する。
pytorchモデルのONNXフォーマットexport
torch_model # PyTorch model
onnx_path = 'model.onnx'
torch.onnx.export(traced_model, inp[0], onnx_path,
export_params=True,
training=torch.onnx.TrainingMode.EVAL,
opset_version=18)
ONNXフォーマットファイルからモデル読み込み
onnx_model = onnx.load('model.onnx')
モデルのノード情報列挙
for node in model.graph.node:
print("node name:", node.name)
print("node input:", node.input)
print("node output:", node.output)
print("node attribute:", node.attribute)
ノードの編集
実際業務で用いた例。ノードに異常な値が入力されているのが見つかったため異常な入力値を取り除いた。(Padノードの仕様として取り除かれた入力にはデフォルト値が与えられる)
for node in onnx_model.graph.node:
if node.op_type == 'Pad':
node.input.remove(node.input[2])
onnx.save(onnx_model, 'model_fixed.onn') # 修正モデルをファイル保存
ONNX Runtimeによるモデル推論実行
session = onnxruntime.InferenceSession('model.onnx')
input_dic = {input.name: None for input in session.get_inputs()}
output_names = [output.name for output in session.get_outputs()]
dummy_input = np.random.rand(1,3,640,640).astype(np.float32)
input_dic['input.1'] = dummy_input
result = session.run(output_names, input_dic)
ONNXファイルの可視化
デモ