LoginSignup
0
0

More than 3 years have passed since last update.

ONNXで eliminate_nop_dropout 最適化

Posted at

ONNXがサポートしている最適化 eliminate_nop_dropout を調べてみました。

対応するソースはおそらくここ。
https://github.com/onnx/onnx/blob/master/onnx/optimizer/passes/eliminate_nop_dropout.h

  bool patternMatchPredicate(Node* node) override {
    return (node->kind() == kDropout && node->hasAttribute(kratio)) &&
        node->f(kratio) == 0.0;
  }

  bool runTransform(Node* node, Graph&, NodeDestroyType& destroy_current)
      override {
    // Don't assume that theres only one output.
    for (size_t i = 0; i < node->outputs().size(); ++i) {
      node->outputs()[i]->replaceAllUsesWith(node->input());
    }
    destroy_current = NodeDestroyType::DestroyOne;
    return true;
  }

連続するnopとdropoutではなく、nopと同じ意味のdropout(ratio=0.0)を削除しています。なるほど。

さっそくテスト用のグラフを作って試しました。

eliminate_nop_dropout.onnx.png

左のDropoutがratio=0.0, 右のDropoutがratio=0.5に設定してあります。

最適化パスにeliminate_nop_dropoutを追加して、optimizer.optimizeを呼び出して最適化します。

# 最適化パスを指定
passes = ['eliminate_nop_dropout']

optimized_model = optimizer.optimize(model_def, passes)

eliminate_nop_dropout_optimized.onnx.png

予想通り、左側のdropoutのみ削除されました。eliminate_nop_dropout は、ratioが0.0に設定されたnopと同じ(何もしない)dropoutを削除します。

全ソースです。

import onnx
from onnx import helper
from onnx import TensorProto
from onnx import optimizer

X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 2])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 4])

node_def = helper.make_node(
    'Pad',
    ['X'],
    ['pad_out'],
    mode = 'constant',
    value = 1.5,
)

dropout_00 = helper.make_node(
    'Dropout',
    inputs = ['pad_out'],
    outputs = ['dropout_00_out'],
    ratio = 0.0
)

dropout_05 = helper.make_node(
    'Dropout',
    inputs = ['pad_out'],
    outputs = ['dropout_05_out'],
    ratio = 0.5
)

concat = helper.make_node(
    'Concat',
    inputs = ['dropout_00_out', 'dropout_05_out'],
    outputs = ['Y']
)


graph_def = helper.make_graph(
    [node_def, dropout_00, dropout_05, concat],
    'test-model',
    [X],
    [Y]
)

model_def = helper.make_model(
    graph_def,
    producer_name='onnx_example'
)

onnx.save(model_def, 'onnx/eliminate_nop_dropout.onnx')

# 最適化パスを指定
passes = ['eliminate_nop_dropout']

optimized_model = optimizer.optimize(model_def, passes)
onnx.save(optimized_model, 'onnx/eliminate_nop_dropout_optimized.onnx')


0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0