ScatterElementsだけ動かして試す
やりたいこと
- ScatterElementsの動作の理解のため、ONNXの仕様書に書かれていることを再現できるパイソンスクリプトを作成する事
- ONNX runtimeの使い方を理解する事 (オペレータを1つだけ使って、計算を行うグラフを作成し、ONNXファイルを出力する)
data = [
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
]
indices = [
[1, 0, 2],
[0, 2, 1],
]
updates = [
[1.0, 1.1, 1.2],
[2.0, 2.1, 2.2],
]
output = [
[2.0, 1.1, 0.0]
[1.0, 0.0, 2.2]
[0.0, 2.1, 1.2]
]
作成したスクリプト
import numpy as np
import onnx
import onnxruntime
#### input ####
data = np.array([
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
], dtype=np.float32)
indices = np.array([
[1, 0, 2],
[0, 2, 1],
])
updates = np.array([
[1.0, 1.1, 1.2],
[2.0, 2.1, 2.2],
], dtype=np.float32)
#### output(のshape) ####
data_out = np.array([
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
])
#### Attributes ####
axis = 0
#### graphの作成 ####
nodes = [
onnx.helper.make_node(
"ScatterElements",
inputs=["data", "indices", "updates"],
outputs=["data_out"],
axis=axis,
reduction="add", # add, max...
)
]
print("\n=== nodes ===:\n", nodes)
inputs = [
# make_tensor_value_info (name, elem_type, shape).
# テンソルの型情報をテンソルのshapeとデータ型から作成する。
onnx.helper.make_tensor_value_info("data", onnx.TensorProto.FLOAT, data.shape),
onnx.helper.make_tensor_value_info("indices", onnx.TensorProto.INT64, indices.shape),
onnx.helper.make_tensor_value_info("updates", onnx.TensorProto.FLOAT, updates.shape),
]
print("\n=== inputs ===:\n", inputs)
outputs = [
# テンソルの型情報をテンソルのshapeとデータ型から作成する。
onnx.helper.make_tensor_value_info("data_out", onnx.TensorProto.FLOAT, data_out.shape),
]
graph = onnx.helper.make_graph(nodes, "graph", inputs, outputs)
model = onnx.helper.make_model(
graph, ir_version=9, opset_imports=[onnx.helper.make_operatorsetid("", 19)]
)
print("\n=== outputs ===:\n", outputs)
onnx.save(model, "test.onnx")
#### 計算の実行 ####
print(model.SerializeToString())
ses = onnxruntime.InferenceSession(
model.SerializeToString(), providers=["CPUExecutionProvider"]
)
outputs = ses.run(None, {"data": data, "indices": indices, "updates": updates})
print(outputs)
実行結果
https://github.com/onnx/onnx/blob/main/docs/Operators.md#ScatterElements
に書かれている通りのoutputを得ることができた。
:
[array([[2. , 1.1, 0. ],
[1. , 0. , 2.2],
[0. , 2.1, 1.2]], dtype=float32)]
(base) mozaki@yoshitsune:~/004_HDD/001_CODE/python_study/003_add$