I looked into fuse_consecutive_reduce_unsqueeze, one of the optimization passes supported by ONNX.
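The pass is implemented here. At the time of writing it lived in the onnx repository; the optimizer has since been split out into the separate onnx/optimizer project.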
https://github.com/onnx/onnx/blob/master/onnx/optimizer/passes/fuse_consecutive_reduce_unsqueeze.h
bool patternMatchPredicate(Node* node) override {
  // check that the current node is of type Unsqueeze and has defined axes
  bool cur_node_check =
      node->kind() == kUnsqueeze && node->hasAttribute(kaxes);
  if (cur_node_check) {
    Node* prev_node = node->input()->node();
    // check that the previous node a reduction operator and has defined
    // axes/keepdims
    bool reduction_node_check = reduction_operators.find(prev_node->kind()) !=
            reduction_operators.end() &&
        prev_node->hasAttribute(kaxes) && prev_node->hasAttribute(kkeepdims);
    if (reduction_node_check) {
      // insure that keepdims is set to false currently
      return prev_node->i(kkeepdims) == 0 && node->is(kaxes) == prev_node->is(kaxes);
    }
  }
  return false;
}
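The conditions for applying the optimization are a little involved. All of the following must hold:
- The nodes appear in the order Reduce* -> Unsqueeze, meaning the Unsqueeze's input is produced by a reduction operator such as ReduceMax
- The Unsqueeze has an axes attribute
- The reduction operator has both the axes and keepdims attributes
- The reduction's keepdims is 0, and its axes are identical to the Unsqueeze's axes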
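To make these conditions concrete, here is a small Python sketch that mirrors patternMatchPredicate on a hypothetical, minimal node representation. The `Node` dataclass and `can_fuse` function are illustrative names of my own and are not part of ONNX or the optimizer. The set of reduction op types is an assumption based on the standard ONNX Reduce* operators, so check the header for the exact list the pass uses.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union

# Assumed reduction op types (standard ONNX Reduce* operators); the upstream
# reduction_operators set may differ, so treat this as illustrative.
REDUCTION_OPS = {
    "ReduceL1", "ReduceL2", "ReduceLogSum", "ReduceLogSumExp", "ReduceMax",
    "ReduceMean", "ReduceMin", "ReduceProd", "ReduceSum", "ReduceSumSquare",
}


@dataclass
class Node:
    """Hypothetical minimal node: op type, attributes, and the node feeding input 0."""
    op_type: str
    attrs: Dict[str, Union[int, List[int]]] = field(default_factory=dict)
    producer: Optional["Node"] = None  # node that produces this node's first input


def can_fuse(node: Node) -> bool:
    # 1. The current node must be an Unsqueeze with an explicit axes attribute.
    if node.op_type != "Unsqueeze" or "axes" not in node.attrs:
        return False
    prev = node.producer
    # 2. Its input must come from a reduction that defines both axes and keepdims.
    if prev is None or prev.op_type not in REDUCTION_OPS:
        return False
    if "axes" not in prev.attrs or "keepdims" not in prev.attrs:
        return False
    # 3. keepdims must be 0, and both nodes must use exactly the same axes list.
    return prev.attrs["keepdims"] == 0 and node.attrs["axes"] == prev.attrs["axes"]


reduce_max = Node("ReduceMax", {"axes": [1], "keepdims": 0})
print(can_fuse(Node("Unsqueeze", {"axes": [1]}, producer=reduce_max)))  # True
```

Like the C++ code, the sketch requires keepdims to be written explicitly on the reduction node, even though the ONNX spec gives it a default of 1. A reduction that omits keepdims therefore never matches, which is harmless because it would already keep its dimensions. The axes comparison is an exact list comparison, so [1, 2] and [2, 1] do not match.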
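I built a graph that meets these conditions: a ReduceMax (keepdims=0, axes=[1]) followed by an Unsqueeze (axes=[1]).
Specify fuse_consecutive_reduce_unsqueeze as the pass and call optimizer.optimize.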
passes = ['fuse_consecutive_reduce_unsqueeze']
optimized_model = optimizer.optimize(model_def, passes)
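The header excerpt above only shows the matching side. The rewrite that follows presumably sets keepdims to 1 on the reduction and drops the Unsqueeze; the runTransform part of the same header is not quoted here. Below is a toy Python sketch of that rewrite on a hypothetical dict-based graph, which is not the ONNX IR. Two choices are my own simplifications and are not taken from the pass: the single-consumer check, and renaming the reduction's output to the Unsqueeze's output name so that graph output names stay the same.

```python
import copy
from typing import List


def fuse_reduce_unsqueeze(nodes: List[dict], graph_outputs: List[str]) -> List[dict]:
    """Fuse Reduce*(keepdims=0) -> Unsqueeze pairs in a toy, dict-based graph.

    Each node is {"op_type", "inputs", "outputs", "attrs"} (hypothetical format).
    Returns a new node list; the input list is not modified.
    """
    nodes = copy.deepcopy(nodes)
    producer = {out: n for n in nodes for out in n["outputs"]}
    removed = set()  # id() of Unsqueeze nodes that get fused away
    for node in nodes:
        if node["op_type"] != "Unsqueeze" or "axes" not in node["attrs"]:
            continue
        prev = producer.get(node["inputs"][0])
        # Simplification: any op whose name starts with "Reduce" counts as a reduction
        if prev is None or not prev["op_type"].startswith("Reduce"):
            continue
        attrs = prev["attrs"]
        # keepdims must be explicitly 0 and the axes lists must be identical
        if attrs.get("keepdims") != 0 or attrs.get("axes") != node["attrs"]["axes"]:
            continue
        mid = prev["outputs"][0]
        consumers = [n for n in nodes if mid in n["inputs"]]
        # Extra safety check (my assumption): the reduction output must feed only
        # this Unsqueeze, because keepdims=1 changes its shape for every consumer
        if len(consumers) != 1 or mid in graph_outputs:
            continue
        attrs["keepdims"] = 1  # keep the reduced axes as size-1 dimensions
        prev["outputs"][0] = node["outputs"][0]  # reduction now produces the Unsqueeze's output
        removed.add(id(node))
    return [n for n in nodes if id(n) not in removed]


nodes = [
    {"op_type": "ReduceMax", "inputs": ["X"], "outputs": ["reducemax_out"],
     "attrs": {"keepdims": 0, "axes": [1]}},
    {"op_type": "Unsqueeze", "inputs": ["reducemax_out"], "outputs": ["Y"],
     "attrs": {"axes": [1]}},
]
fused = fuse_reduce_unsqueeze(nodes, graph_outputs=["Y"])
print([n["op_type"] for n in fused])  # ['ReduceMax']
```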
The graph after optimization:
The Unsqueeze node has been removed, as intended.
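Why is it safe to drop the Unsqueeze? A ReduceMax with keepdims=0 followed by an Unsqueeze on the same axis produces exactly the same tensor as a ReduceMax with keepdims=1. Here is a quick pure-Python check on a [1, 4] input. The helper functions are hypothetical and implement only the 2-D, axis-1 case, not the general operators.

```python
from typing import List, Union

Matrix = List[List[float]]


def reduce_max_axis1(x: Matrix, keepdims: int) -> Union[List[float], Matrix]:
    """ReduceMax over axis 1 of a 2-D list, with ONNX-style keepdims."""
    row_max = [max(row) for row in x]
    # keepdims=1 keeps axis 1 as a size-1 dimension: shape (N, 1) instead of (N,)
    return [[m] for m in row_max] if keepdims else row_max


def unsqueeze_axis1(v: List[float]) -> Matrix:
    """Unsqueeze with axes=[1] on a 1-D list: shape (N,) -> (N, 1)."""
    return [[e] for e in v]


x = [[1.0, 5.0, 3.0, 2.0]]  # shape [1, 4], like X in the script
before = unsqueeze_axis1(reduce_max_axis1(x, keepdims=0))  # ReduceMax(keepdims=0) -> Unsqueeze
after = reduce_max_axis1(x, keepdims=1)  # fused form: ReduceMax(keepdims=1)
print(before, after)  # [[5.0]] [[5.0]]
```

The result has shape [1, 1] in both cases.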
Full source
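import os

import onnx
from onnx import helper
from onnx import TensorProto

try:
    # onnx.optimizer is bundled with older onnx releases (removed in onnx 1.9)
    from onnx import optimizer
except ImportError:
    # Newer setups: the optimizer lives in the separate onnxoptimizer package
    import onnxoptimizer as optimizer

X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 4])
# ReduceMax over axis 1 followed by Unsqueeze on axis 1 gives shape [1, 1]
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 1])

# keepdims=0 drops axis 1, so reducemax_out has shape [1]
reducemax = helper.make_node(
    'ReduceMax',
    inputs=['X'],
    outputs=['reducemax_out'],
    keepdims=0,
    axes=[1],
)
# Unsqueeze restores axis 1, using the same axes as the ReduceMax
unsqueeze = helper.make_node(
    'Unsqueeze',
    inputs=['reducemax_out'],
    outputs=['Y'],
    axes=[1],
)
graph_def = helper.make_graph(
    [reducemax, unsqueeze],
    'test-model',
    [X],
    [Y],
)
# Pin opset 11: from opset 13 on, Unsqueeze takes axes as an input, not an attribute
model_def = helper.make_model(
    graph_def,
    producer_name='onnx_example',
    opset_imports=[helper.make_opsetid('', 11)],
)
onnx.checker.check_model(model_def)

os.makedirs('onnx', exist_ok=True)
onnx.save(model_def, 'onnx/fuse_consecutive_reduce_unsqueeze.onnx')

# Specify the optimization pass
passes = ['fuse_consecutive_reduce_unsqueeze']
optimized_model = optimizer.optimize(model_def, passes)
onnx.save(optimized_model, 'onnx/fuse_consecutive_reduce_unsqueeze_optimized.onnx')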