ONNXがサポートしている最適化 fuse_matmul_add_bias_into_gemm を調べてみました。
最適化してる場所はここ
https://github.com/onnx/onnx/blob/master/onnx/optimizer/passes/fuse_matmul_add_bias_into_gemm.h
コメントの方がわかりやすい。
// Before:
// Z = MatMul(X, Y)
// A = Z + Bias
// After:
// A = Gemm(X, Y, Bias)
//
// the pass can handle the case when:
// case 1: Bias is 1D tensor and Bias.dim[0] == Z.dim[1]
// case 2: Bias is 2D tensor and Bias.dim[0] == Z.dim[0] or 1
// and Bias.dim[1] = Z.dim[1]
次元があっていれば、連続する matmul -> add をgemmに置き換えます。
最適化前のグラフ
passにfuse_matmul_add_bias_into_gemmを設定して、optimizer.optimizeを呼び出します。
passes = ['fuse_matmul_add_bias_into_gemm']
optimized_model = optimizer.optimize(model_def, passes)
最適化後のグラフ
上手く行きました。
全ソース
import onnx
from onnx import helper
from onnx import TensorProto
from onnx import optimizer
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 5])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [5, 1])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [1, 1])
Z = helper.make_tensor_value_info('Z', TensorProto.FLOAT, [1, 1])
matmul = helper.make_node(
'MatMul',
inputs = ['X', 'Y'],
outputs = ['matmul_out'],
)
add = helper.make_node(
'Add',
inputs = ['matmul_out', 'B'],
outputs = ['Z'],
)
graph_def = helper.make_graph(
[matmul, add],
'test-model',
[X, Y, B],
[Z]
)
model_def = helper.make_model(
graph_def,
producer_name='onnx_example'
)
onnx.save(model_def, 'onnx/fuse_matmul_add_bias_into_gemm.onnx')
onnx.checker.check_model(model_def)
# 最適化パスを指定
passes = ['fuse_matmul_add_bias_into_gemm']
optimized_model = optimizer.optimize(model_def, passes)
onnx.save(optimized_model, 'onnx/fuse_matmul_add_bias_into_gemm_optimized.onnx')