ONNXがサポートしている最適化 eliminate_nop_monotone_argmax を調べてみました。
ちょっとこれは名前だけでは何をしているのかわからないです。
対応するソースはこちら。
const std::unordered_set<NodeKind> monotone_node_no_axis_kind{kLog,
kExp,
kSqrt};
const std::unordered_set<NodeKind> monotone_node_axis_kind{kSoftmax,
kLogSoftmax};
static inline bool satisfies_monotone_condition(int64_t axis, Node* node) {
if (monotone_node_no_axis_kind.find(node->kind()) !=
monotone_node_no_axis_kind.end()) {
return true;
}
if (monotone_node_axis_kind.find(node->kind()) !=
monotone_node_axis_kind.end()) {
if (node->hasAttribute(kaxis)) {
return axis == node->i(kaxis);
}
}
return false;
}
bool patternMatchPredicate(Node* node) override {
if (node->kind() == kArgMax) {
if (node->hasAttribute(kaxis)) {
auto node_axis = node->i(kaxis);
return node->inputs().size() == 1 &&
satisfies_monotone_condition(node_axis, node->input()->node());
}
}
return false;
}
ソース見てもいまいちわからなかったのですが、ONNXのissueの方でやり取りを見つけました。
softmaxの後がArgmaxの時、softmaxがあってもなくても大小関係がかわらないのでsoftmaxは削除できるとのこと。argmaxと対になるargminが最適化の対象外なのは需要が無いからでしょうね。この最適化に対応するOperatorは、Log、Exp, Sqrt, Softmax, KLogSoftmaxの5つ。早速試してみました。
このように、argmaxの前にRelu、Softmax, Logを並べます。最適化をするとSoftmaxとLogが消えるはずです。Reluは値が全部の負の場合Argmaxの結果が変わってしまうので最適化できません。
passes = ['eliminate_nop_monotone_argmax']
optimized_model = optimizer.optimize(model_def, passes)
予想通り、Reluを残し、softmaxとlogが削除されました。
エッジ向けで最後のsoftmaxを手で取るのと同じですね。ただ、普通に入手できるONNXモデルだとsoftmaxが取れる状態のモデルってなさそう。
全ソースはこちら
import onnx
from onnx import helper
from onnx import TensorProto
from onnx import optimizer
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [10])
Y = helper.make_tensor_value_info('Y', TensorProto.INT64, [3])
relu = helper.make_node(
'Relu',
inputs = ['X'],
outputs = ['relu_out']
)
argmax_0 = helper.make_node(
'ArgMax',
inputs = ['relu_out'],
outputs = ['argmax0_out'],
axis=0
)
softmax = helper.make_node(
'Softmax',
inputs = ['X'],
outputs = ['softmax_out'],
axis = 0
)
argmax_1 = helper.make_node(
'ArgMax',
inputs = ['softmax_out'],
outputs = ['argmax1_out'],
axis = 0
)
log = helper.make_node(
'Log',
inputs = ['X'],
outputs = ['log_out'],
axis=0
)
argmax_2 = helper.make_node(
'ArgMax',
inputs = ['log_out'],
outputs = ['argmax2_out'],
axis = 0
)
concat = helper.make_node(
'Concat',
inputs = ['argmax0_out', 'argmax1_out', 'argmax2_out'],
outputs = ['Y']
)
graph_def = helper.make_graph(
[relu, argmax_0, softmax, argmax_1, log, argmax_2, concat],
'test-model',
[X],
[Y]
)
model_def = helper.make_model(
graph_def,
producer_name='onnx_example'
)
onnx.save(model_def, 'onnx/eliminate_nop_monotone_argmax.onnx')
# 最適化パスを指定
passes = ['eliminate_nop_monotone_argmax']
optimized_model = optimizer.optimize(model_def, passes)
onnx.save(optimized_model, 'onnx/eliminate_nop_monotone_argmax_optimized.onnx')