ONNXがサポートしている最適化 split について調べました。
最適化をしている場所はここ。
https://github.com/onnx/onnx/blob/master/onnx/optimizer/passes/split.h
コメントを読んでいくと、推論時にグラフの中で定数ととして扱えるものを別のsubgraphに分離する最適化のようです。面白そう。
早速試してみました。add0の入力が定数になるようなグラフを作成します。
左上のAddが定数のみの足し算です。グラフの計算を実行する時に毎回計算する必要が無い部分です。真ん中のAddは、入力Xによって結果が変わるので毎回計算必要です。
最適化は、定数として扱える部分を抜き出すのにsplit_init、推論時に計算が必要な部分を抜き出すのにsplit_predictを指定して、2回呼び出しました。使い方は合ってるはす・・・。
# 最適化パスを指定
passes = ['split_init', ]
optimized_model_init = optimizer.optimize(model_def, passes)
onnx.save(optimized_model_init, 'onnx/split_init.onnx')
passes = ['split_predict', ]
optimized_model_pred = optimizer.optimize(model_def, passes)
onnx.save(optimized_model_pred, 'onnx/split_pred.onnx')
最適化後の初期化部分がこっち。定数同士の加算部分だけを別グラフにし、add0_outという名前で出力しています。初期値扱いで、一度計算すれば再計算は不要です。
推論の計算部分がこっち。上のグラフで計算したadd0_outの結果を使って、全体の計算を行っています。再計算が不要な部分は、別グラフに分離してあるので、定数などは気にせず全部の計算を行えばOKです。
上手く分離されてますね。面白い。
全ソースです。
import numpy as np
import onnx
from onnx import helper, numpy_helper
from onnx import TensorProto
from onnx import optimizer
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1])
A = np.ones((1,), dtype=np.float32)
B = np.ones((1,), dtype=np.float32)
const0 = onnx.helper.make_node(
'Constant',
inputs=[],
outputs=['const_0'],
value=onnx.helper.make_tensor(
name='const_tensor',
data_type=onnx.TensorProto.FLOAT,
dims=A.shape,
vals=A.flatten().astype(float),
),
)
const1 = onnx.helper.make_node(
'Constant',
inputs=[],
outputs=['const_1'],
value=onnx.helper.make_tensor(
name='const_tensor',
data_type=onnx.TensorProto.FLOAT,
dims=B.shape,
vals=B.flatten().astype(float),
),
)
add0 = helper.make_node(
'Add',
inputs = ['const_0', "const_1"],
outputs = ['add0_out'],
)
add1 = helper.make_node(
'Add',
inputs=['X', 'add0_out'],
outputs=['Y']
)
gemm = helper.make_node(
'Gemm',
['X'],
['Y'],
transA=1
)
graph_def = helper.make_graph(
[const0, const1, add0, add1],
'test-model',
[X],
[Y],
initializer=[]
)
model_def = helper.make_model(
graph_def,
producer_name='onnx_example'
)
onnx.save(model_def, 'onnx/split.onnx')
onnx.checker.check_model(model_def)
# 最適化パスを指定
passes = ['split_init', ]
optimized_model_init = optimizer.optimize(model_def, passes)
onnx.save(optimized_model_init, 'onnx/split_init.onnx')
passes = ['split_predict', ]
optimized_model_pred = optimizer.optimize(model_def, passes)
onnx.save(optimized_model_pred, 'onnx/split_pred.onnx')