ONNXがサポートしている最適化 fuse_consecutive_concats を調べてみました。
同じaxisで連続するConcatノード（あるConcatの出力を別のConcatが入力として受け取る形）を、1つのConcatにまとめてくれます。
最適化前のグラフ（3つのDropoutの出力を、axis=-1 の Concat 2段で連結しています）
passes に fuse_consecutive_concats を指定して、optimizer.optimize を呼び出します。
passes = ['fuse_consecutive_concats']
optimized_model = optimizer.optimize(model_def, passes)
最適化後のグラフ（2段のConcatが、3つのDropout出力を直接入力に持つ1つのConcatにまとめられています）
上手く行ってますね。
全ソース
import os

import onnx
from onnx import helper
from onnx import TensorProto
# NOTE: `onnx.optimizer` was removed from the onnx package in v1.9; on newer
# versions the same `optimize(model, passes)` API lives in the separate
# `onnxoptimizer` package.
from onnx import optimizer
# Build a small graph with two chained Concat nodes so that the
# `fuse_consecutive_concats` optimization pass has something to fuse:
#
#   X -> Dropout -> dropout0_out --+
#   X -> Dropout -> dropout1_out --+-> Concat -> concat0_out --+
#   X -> Dropout -> dropout2_out ------------------------------+-> Concat -> Y
#
# X is [1, 2]; concatenating three copies along the last axis gives Y [1, 6].
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 2])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 6])

# Three independent Dropout nodes, all reading X, producing dropout{0,1,2}_out.
dropout0, dropout1, dropout2 = [
    helper.make_node(
        'Dropout',
        inputs=['X'],
        outputs=[f'dropout{i}_out'],
    )
    for i in range(3)
]

# The first Concat joins two of the Dropout outputs on the last axis...
concat0 = helper.make_node(
    'Concat',
    inputs=['dropout0_out', 'dropout1_out'],
    outputs=['concat0_out'],
    axis=-1,
)
# ...and the second Concat consumes concat0's output on the same axis.
# This Concat -> Concat chain is what the pass is expected to collapse.
concat1 = helper.make_node(
    'Concat',
    inputs=['concat0_out', 'dropout2_out'],
    outputs=['Y'],
    axis=-1,
)

graph_def = helper.make_graph(
    [dropout0, dropout1, dropout2, concat0, concat1],
    'test-model',
    [X],
    [Y],
)

model_def = helper.make_model(
    graph_def,
    producer_name='onnx_example',
)

# Directory the before/after models are written to. Create it up front:
# onnx.save does not create missing parent directories.
OUTPUT_DIR = 'onnx'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Validate before writing so a malformed model is never saved to disk.
onnx.checker.check_model(model_def)
onnx.save(model_def, os.path.join(OUTPUT_DIR, 'fuse_consecutive_concats.onnx'))

# Specify the optimization passes: run only the pass under test.
passes = ['fuse_consecutive_concats']
optimized_model = optimizer.optimize(model_def, passes)

# Re-validate: the rewritten graph must still be a well-formed model.
onnx.checker.check_model(optimized_model)
onnx.save(
    optimized_model,
    os.path.join(OUTPUT_DIR, 'fuse_consecutive_concats_optimized.onnx'),
)