LoginSignup
1
0

More than 1 year has passed since last update.

ONNXで eliminate_nop_monotone_argmax 最適化

Last updated at Posted at 2019-09-26

ONNXがサポートしている最適化 eliminate_nop_monotone_argmax を調べてみました。

ちょっとこれは名前だけでは何をしているのかわからないです。

対応するソースはこちら。

const std::unordered_set<NodeKind> monotone_node_no_axis_kind{kLog,
                                                              kExp,
                                                              kSqrt};

const std::unordered_set<NodeKind> monotone_node_axis_kind{kSoftmax,
                                                           kLogSoftmax};


static inline bool satisfies_monotone_condition(int64_t axis, Node* node) {
    if (monotone_node_no_axis_kind.find(node->kind()) !=
        monotone_node_no_axis_kind.end()) {
      return true;
    }
    if (monotone_node_axis_kind.find(node->kind()) !=
        monotone_node_axis_kind.end()) {
      if (node->hasAttribute(kaxis)) {
        return axis == node->i(kaxis);
      }
    }
    return false;
  }

  bool patternMatchPredicate(Node* node) override {
    if (node->kind() == kArgMax) {
      if (node->hasAttribute(kaxis)) {
        auto node_axis = node->i(kaxis);
        return node->inputs().size() == 1 &&
            satisfies_monotone_condition(node_axis, node->input()->node());
      }
    }
    return false;
  }

ソース見てもいまいちわからなかったのですが、ONNXのissueの方でやり取りを見つけました。

softmaxの後がArgmaxの時、softmaxがあってもなくても大小関係がかわらないのでsoftmaxは削除できるとのこと。argmaxと対になるargminが最適化の対象外なのは需要が無いからでしょうね。この最適化に対応するOperatorは、Log、Exp, Sqrt, Softmax, KLogSoftmaxの5つ。早速試してみました。

eliminate_nop_monotone_argmax.onnx (1).png

このように、argmaxの前にRelu、Softmax, Logを並べます。最適化をするとSoftmaxとLogが消えるはずです。Reluは値が全部の負の場合Argmaxの結果が変わってしまうので最適化できません。

passes = ['eliminate_nop_monotone_argmax']

optimized_model = optimizer.optimize(model_def, passes)

eliminate_nop_monotone_argmax_optimized.onnx (1).png

予想通り、Reluを残し、softmaxとlogが削除されました。

エッジ向けで最後のsoftmaxを手で取るのと同じですね。ただ、普通に入手できるONNXモデルだとsoftmaxが取れる状態のモデルってなさそう。

全ソースはこちら

import onnx
from onnx import helper
from onnx import TensorProto
from onnx import optimizer

X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [10])
Y = helper.make_tensor_value_info('Y', TensorProto.INT64, [3])


relu = helper.make_node(
    'Relu',
    inputs = ['X'],
    outputs = ['relu_out']
)

argmax_0 = helper.make_node(
    'ArgMax',
    inputs = ['relu_out'],
    outputs = ['argmax0_out'],
    axis=0
)

softmax = helper.make_node(
    'Softmax',
    inputs = ['X'],
    outputs = ['softmax_out'],
    axis = 0
)

argmax_1 = helper.make_node(
    'ArgMax',
    inputs = ['softmax_out'],
    outputs = ['argmax1_out'],
    axis = 0
)

log = helper.make_node(
    'Log',
    inputs = ['X'],
    outputs = ['log_out'],
    axis=0
)

argmax_2 = helper.make_node(
    'ArgMax',
    inputs = ['log_out'],
    outputs = ['argmax2_out'],
    axis = 0
)


concat = helper.make_node(
    'Concat',
    inputs = ['argmax0_out', 'argmax1_out', 'argmax2_out'],
    outputs = ['Y']
)

graph_def = helper.make_graph(
    [relu, argmax_0, softmax, argmax_1, log, argmax_2, concat],
    'test-model',
    [X],
    [Y]
)

model_def = helper.make_model(
    graph_def,
    producer_name='onnx_example'
)

onnx.save(model_def, 'onnx/eliminate_nop_monotone_argmax.onnx')

# 最適化パスを指定
passes = ['eliminate_nop_monotone_argmax']

optimized_model = optimizer.optimize(model_def, passes)
onnx.save(optimized_model, 'onnx/eliminate_nop_monotone_argmax_optimized.onnx')

1
0
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0