ONNXがサポートしている最適 fuse_transpose_into_gemm を調べてみました。
最適化をしている場所はここ。
https://github.com/onnx/onnx/blob/master/onnx/optimizer/passes/fuse_transpose_into_gemm.h
bool patternMatchPredicate(Node* node) override {
return node->kind() == kGemm;
}
bool runTransform(Node* n, Graph&, NodeDestroyType& destroy_current)
override {
const std::vector<int64_t> simple_trans_perm({1, 0});
destroy_current = NodeDestroyType::DestroyZero;
bool ret_val = false;
for (size_t i : {0, 1}) {
auto inp = n->inputs()[i];
auto trans = i == 0 ? ktransA : ktransB;
if (inp->node()->kind() == kTranspose &&
inp->node()->is(kperm) == simple_trans_perm) {
n->replaceInput(i, inp->node()->input());
n->i_(trans, n->hasAttribute(trans) ? !n->i(trans) : 1);
if (inp->uses().size() == 0) {
inp->node()->destroy();
ret_val = true;
}
}
}
return ret_val;
}
};
gemmの入力がtransposeで、transposeのpermが[0, 1]の時、transposeをgemmにfuseするようです。
グラフ作って試してみました。
transpose = helper.make_node(
'Transpose',
perm = [1, 0],
inputs = ['X1'],
outputs = ['transpose_out'],
)
gemm = helper.make_node(
'Gemm',
['transpose_out', 'X2'],
['Y'],
transA=1
)
最初のTransposeが単純な転置、続くgemmの入力でも転置をいれています。
最適化前のグラフがこう。
passにfuse_transpose_into_gemmを設定して、optimizer.optimizeを実行します。
passes = ['fuse_transpose_into_gemm']
optimized_model = optimizer.optimize(model_def, passes)
gemmの情報を見てみると、上で設定したtransAが0(転置無し)になっています。
op_type: "Gemm"
attribute {
name: "transA"
i: 0
type: INT
}
上手くいってますね。
全ソースです。
import numpy as np
import onnx
from onnx import helper, numpy_helper
from onnx import TensorProto
from onnx import optimizer
X1 = helper.make_tensor_value_info('X1', TensorProto.FLOAT, [5, 5])
X2 = helper.make_tensor_value_info('X2', TensorProto.FLOAT, [5, 5])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [5, 5])
transpose = helper.make_node(
'Transpose',
perm = [1, 0],
inputs = ['X1'],
outputs = ['transpose_out'],
)
gemm = helper.make_node(
'Gemm',
['transpose_out', 'X2'],
['Y'],
transA=1
)
graph_def = helper.make_graph(
[transpose, gemm],
'test-model',
[X1, X2],
[Y]
)
model_def = helper.make_model(
graph_def,
producer_name='onnx_example'
)
onnx.save(model_def, 'onnx/fuse_transpose_into_gemm.onnx')
onnx.checker.check_model(model_def)
# 最適化パスを指定
passes = ['fuse_transpose_into_gemm']
optimized_model = optimizer.optimize(model_def, passes)
onnx.save(optimized_model, 'onnx/fuse_transpose_into_gemm_optimized.onnx')
node = optimized_model.graph.node[0]
print(node)