
More than 3 years have passed since last update.

The fuse_consecutive_reduce_unsqueeze optimization in ONNX


I looked into fuse_consecutive_reduce_unsqueeze, one of the optimization passes that ONNX supports.

The pass is implemented here:
https://github.com/onnx/onnx/blob/master/onnx/optimizer/passes/fuse_consecutive_reduce_unsqueeze.h

bool patternMatchPredicate(Node* node) override {
  // Check that the current node is an Unsqueeze with its axes attribute defined.
  bool cur_node_check =
      node->kind() == kUnsqueeze && node->hasAttribute(kaxes);
  if (cur_node_check) {
    Node* prev_node = node->input()->node();
    // Check that the previous node is a reduction operator with both
    // axes and keepdims defined.
    bool reduction_node_check =
        reduction_operators.find(prev_node->kind()) !=
            reduction_operators.end() &&
        prev_node->hasAttribute(kaxes) && prev_node->hasAttribute(kkeepdims);
    if (reduction_node_check) {
      // Ensure keepdims is currently 0 and both nodes use the same axes.
      return prev_node->i(kkeepdims) == 0 &&
          node->is(kaxes) == prev_node->is(kaxes);
    }
  }
  return false;
}

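Deciding whether the optimization applies is a little tricky. All of the following must hold:

  • The nodes appear in the order Reduction -> Unsqueeze, and the Unsqueeze has an axes attribute.
  • The Reduction operator has both the axes and keepdims attributes set explicitly.
  • The Reduction's keepdims is 0, and the Reduction's axes are identical to the Unsqueeze's axes.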
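
The conditions above can be sketched in Python to make the matching logic easier to follow. This is only an illustration: the dict-based node records and the can_fuse helper are hypothetical (the real pass works on ONNX IR Node objects in C++), and the set of reduction operator names is my assumption about what reduction_operators contains.

```python
# Hypothetical, simplified node records; NOT the ONNX IR used by the real pass.
# Assumed contents of reduction_operators (not confirmed by the excerpt above).
REDUCTION_OPS = {
    "ReduceL1", "ReduceL2", "ReduceLogSum", "ReduceLogSumExp",
    "ReduceMax", "ReduceMean", "ReduceMin", "ReduceProd",
    "ReduceSum", "ReduceSumSquare",
}


def can_fuse(reduce_node, unsqueeze_node):
    """Return True when a Reduction -> Unsqueeze pair meets all conditions."""
    # The current node must be an Unsqueeze with an explicit axes attribute.
    if unsqueeze_node["op_type"] != "Unsqueeze":
        return False
    if "axes" not in unsqueeze_node["attrs"]:
        return False
    # The Unsqueeze must consume the output of the previous node.
    if unsqueeze_node["inputs"][0] not in reduce_node["outputs"]:
        return False
    # The previous node must be a reduction with explicit axes and keepdims.
    if reduce_node["op_type"] not in REDUCTION_OPS:
        return False
    r_attrs = reduce_node["attrs"]
    if "axes" not in r_attrs or "keepdims" not in r_attrs:
        return False
    # keepdims must be 0 and both nodes must use the same axes.
    same_axes = list(r_attrs["axes"]) == list(unsqueeze_node["attrs"]["axes"])
    return r_attrs["keepdims"] == 0 and same_axes


reducemax = {"op_type": "ReduceMax", "attrs": {"axes": [1], "keepdims": 0},
             "inputs": ["X"], "outputs": ["reducemax_out"]}
unsqueeze = {"op_type": "Unsqueeze", "attrs": {"axes": [1]},
             "inputs": ["reducemax_out"], "outputs": ["Y"]}
print(can_fuse(reducemax, unsqueeze))
```

Note that a reduction node that simply omits keepdims does not match in this sketch, mirroring the hasAttribute(kkeepdims) check in the C++ predicate.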

I built a graph that satisfies these conditions.

fuse_consecutive_reduce_unsqueeze.onnx.png

Specify fuse_consecutive_reduce_unsqueeze as the pass and call optimizer.optimize.

passes = ['fuse_consecutive_reduce_unsqueeze']

optimized_model = optimizer.optimize(model_def, passes)

The graph after optimization:

fuse_consecutive_reduce_unsqueeze_optimized.onnx.png

The Unsqueeze node has been removed cleanly.
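
Why is it safe to drop the Unsqueeze? As far as I can tell, the pass folds it into the reduction by switching keepdims to 1, though that is my reading and not something the predicate shown earlier confirms. The small sketch below checks the underlying equivalence on plain nested lists. The helpers reduce_max_axis1 and unsqueeze_axis1 are hypothetical stand-ins written for this example, not ONNX APIs.

```python
def reduce_max_axis1(rows, keepdims):
    """ReduceMax over axis 1 of a 2-D nested list (hypothetical helper)."""
    if keepdims:
        return [[max(row)] for row in rows]  # shape [N, 1]
    return [max(row) for row in rows]        # shape [N]


def unsqueeze_axis1(values):
    """Insert a size-1 dimension at axis 1 of a 1-D list (hypothetical helper)."""
    return [[v] for v in values]


x = [[3.0, 1.0, 4.0, 1.5]]  # shape [1, 4], like the input X in this article

# Original graph: ReduceMax(keepdims=0, axes=[1]) followed by Unsqueeze(axes=[1]).
two_step = unsqueeze_axis1(reduce_max_axis1(x, keepdims=0))

# Fused form: a single ReduceMax(keepdims=1, axes=[1]).
fused = reduce_max_axis1(x, keepdims=1)

print(two_step == fused)
```

Both forms produce a tensor of shape [1, 1], which is why the two nodes can be replaced by one.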

Full source

import onnx
from onnx import helper
from onnx import TensorProto
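# Note: onnx.optimizer was removed in ONNX 1.9; on newer versions the same
# passes are provided by the separate onnxoptimizer package.
from onnx import optimizer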

X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 4])
# ReduceMax over axis 1 (keepdims=0) gives shape [1]; Unsqueeze at axis 1 gives [1, 1].
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 1])

reducemax = helper.make_node(
    'ReduceMax',
    keepdims = 0,
    axes = [1],
    inputs = ['X'],
    outputs = ['reducemax_out'],
)


unsqueeze = helper.make_node(
    'Unsqueeze',
    axes = [1],
    inputs = ['reducemax_out'],
    outputs = ['Y'],
)


graph_def = helper.make_graph(
    [reducemax, unsqueeze],
    'test-model',
    [X],
    [Y]
)

model_def = helper.make_model(
    graph_def,
    producer_name='onnx_example'
)

onnx.save(model_def, 'onnx/fuse_consecutive_reduce_unsqueeze.onnx')

onnx.checker.check_model(model_def)

# Specify the optimization pass
passes = ['fuse_consecutive_reduce_unsqueeze']

optimized_model = optimizer.optimize(model_def, passes)
onnx.save(optimized_model, 'onnx/fuse_consecutive_reduce_unsqueeze_optimized.onnx')
