LoginSignup
0
0

ScatterElementsだけ動かすモデルをOnnx Runtimeを使って作成する

Posted at

ScatterElementsだけ動かして試す

やりたいこと

  • ScatterElementsの動作の理解のため、ONNXの仕様書に書かれていることを再現できるパイソンスクリプトを作成する事
  • ONNX runtimeの使い方を理解する事 (オペレータを1つだけ使って、計算を行うグラフを作成し、ONNXファイルを出力する)

data = [
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
]
indices = [
    [1, 0, 2],
    [0, 2, 1],
]
updates = [
    [1.0, 1.1, 1.2],
    [2.0, 2.1, 2.2],
]
output = [
    [2.0, 1.1, 0.0]
    [1.0, 0.0, 2.2]
    [0.0, 2.1, 1.2]
]

作成したスクリプト

import numpy as np
import onnx
import onnxruntime

#### input ####
data = np.array([
	[0.0, 0.0, 0.0],
	[0.0, 0.0, 0.0],
	[0.0, 0.0, 0.0],
], dtype=np.float32)

indices = np.array([
	[1, 0, 2],
	[0, 2, 1],
])

updates = np.array([
	[1.0, 1.1, 1.2],
	[2.0, 2.1, 2.2],
], dtype=np.float32)

#### output(のshape) ####
data_out = np.array([
	[0.0, 0.0, 0.0],
	[0.0, 0.0, 0.0],
	[0.0, 0.0, 0.0],
])

#### Attributes ####
axis = 0

#### graphの作成 ####
nodes = [
	onnx.helper.make_node(
		"ScatterElements",
		inputs=["data", "indices", "updates"],
		outputs=["data_out"],
		axis=axis,
		reduction="add", # add, max...
		)
]
print("\n=== nodes ===:\n", nodes)

inputs = [
	# make_tensor_value_info (name, elem_type, shape).
	# テンソルの型情報をテンソルのshapeとデータ型から作成する。
    onnx.helper.make_tensor_value_info("data", onnx.TensorProto.FLOAT, data.shape),
    onnx.helper.make_tensor_value_info("indices", onnx.TensorProto.INT64, indices.shape),
    onnx.helper.make_tensor_value_info("updates", onnx.TensorProto.FLOAT, updates.shape),
]
print("\n=== inputs ===:\n", inputs)

outputs = [
	# テンソルの型情報をテンソルのshapeとデータ型から作成する。
    onnx.helper.make_tensor_value_info("data_out", onnx.TensorProto.FLOAT, data_out.shape),
]

graph = onnx.helper.make_graph(nodes, "graph", inputs, outputs)
model = onnx.helper.make_model(
    graph, ir_version=9, opset_imports=[onnx.helper.make_operatorsetid("", 19)]
)
print("\n=== outputs ===:\n", outputs)

onnx.save(model, "test.onnx")

#### 計算の実行 ####
print(model.SerializeToString())
ses = onnxruntime.InferenceSession(
    model.SerializeToString(), providers=["CPUExecutionProvider"]
)

outputs = ses.run(None, {"data": data, "indices": indices, "updates": updates})
print(outputs)

実行結果

https://github.com/onnx/onnx/blob/main/docs/Operators.md#ScatterElements
に書かれている通りのoutputを得ることができた。

  :
[array([[2. , 1.1, 0. ],
       [1. , 0. , 2.2],
       [0. , 2.1, 1.2]], dtype=float32)]
(base) mozaki@yoshitsune:~/004_HDD/001_CODE/python_study/003_add$ 
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0