LoginSignup
1
0

More than 3 years have passed since last update.

ONNXで fuse_transpose_into_gemm 最適化

Posted at

ONNXがサポートしている最適 fuse_transpose_into_gemm を調べてみました。

最適化をしている場所はここ。
https://github.com/onnx/onnx/blob/master/onnx/optimizer/passes/fuse_transpose_into_gemm.h

bool patternMatchPredicate(Node* node) override {
    return node->kind() == kGemm;
  }
  bool runTransform(Node* n, Graph&, NodeDestroyType& destroy_current)
      override {
    const std::vector<int64_t> simple_trans_perm({1, 0});
    destroy_current = NodeDestroyType::DestroyZero;
    bool ret_val = false;
    for (size_t i : {0, 1}) {
      auto inp = n->inputs()[i];
      auto trans = i == 0 ? ktransA : ktransB;
      if (inp->node()->kind() == kTranspose &&
          inp->node()->is(kperm) == simple_trans_perm) {
        n->replaceInput(i, inp->node()->input());
        n->i_(trans, n->hasAttribute(trans) ? !n->i(trans) : 1);
        if (inp->uses().size() == 0) {
          inp->node()->destroy();
          ret_val = true;
        }
      }
    }
    return ret_val;
  }
};

gemmの入力がtransposeで、transposeのpermが[0, 1]の時、transposeをgemmにfuseするようです。

グラフ作って試してみました。

transpose = helper.make_node(
    'Transpose',
    perm = [1, 0],
    inputs = ['X1'],
    outputs = ['transpose_out'],
)

gemm = helper.make_node(
    'Gemm',
    ['transpose_out', 'X2'],
    ['Y'],
    transA=1
)

最初のTransposeが単純な転置、続くgemmの入力でも転置をいれています。
最適化前のグラフがこう。

fuse_transpose_into_gemm.onnx.png

passにfuse_transpose_into_gemmを設定して、optimizer.optimizeを実行します。

passes = ['fuse_transpose_into_gemm']

optimized_model = optimizer.optimize(model_def, passes)

最適化を実行するとTransposeが削除されます。
fuse_transpose_into_gemm_optimized.onnx.png

gemmの情報を見てみると、上で設定したtransAが0(転置無し)になっています。

op_type: "Gemm"
attribute {
  name: "transA"
  i: 0
  type: INT
}

上手くいってますね。

全ソースです。

import numpy as np
import onnx
from onnx import helper, numpy_helper
from onnx import TensorProto
from onnx import optimizer

X1 = helper.make_tensor_value_info('X1', TensorProto.FLOAT, [5, 5])
X2 = helper.make_tensor_value_info('X2', TensorProto.FLOAT, [5, 5])

Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [5, 5])

transpose = helper.make_node(
    'Transpose',
    perm = [1, 0],
    inputs = ['X1'],
    outputs = ['transpose_out'],
)

gemm = helper.make_node(
    'Gemm',
    ['transpose_out', 'X2'],
    ['Y'],
    transA=1
)

graph_def = helper.make_graph(
    [transpose, gemm],
    'test-model',
    [X1, X2],
    [Y]
)

model_def = helper.make_model(
    graph_def,
    producer_name='onnx_example'
)

onnx.save(model_def, 'onnx/fuse_transpose_into_gemm.onnx')

onnx.checker.check_model(model_def)

# 最適化パスを指定
passes = ['fuse_transpose_into_gemm']

optimized_model = optimizer.optimize(model_def, passes)
onnx.save(optimized_model, 'onnx/fuse_transpose_into_gemm_optimized.onnx')

node = optimized_model.graph.node[0]
print(node)

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0