LoginSignup
0
0

More than 3 years have passed since last update.

ONNXで split 最適化

Posted at

ONNXがサポートしている最適化 split について調べました。

最適化をしている場所はここ。
https://github.com/onnx/onnx/blob/master/onnx/optimizer/passes/split.h

コメントを読んでいくと、推論時にグラフの中で定数ととして扱えるものを別のsubgraphに分離する最適化のようです。面白そう。

早速試してみました。add0の入力が定数になるようなグラフを作成します。

split.onnx.png

左上のAddが定数のみの足し算です。グラフの計算を実行する時に毎回計算する必要が無い部分です。真ん中のAddは、入力Xによって結果が変わるので毎回計算必要です。

最適化は、定数として扱える部分を抜き出すのにsplit_init、推論時に計算が必要な部分を抜き出すのにsplit_predictを指定して、2回呼び出しました。使い方は合ってるはす・・・。

# 最適化パスを指定
passes = ['split_init', ]

optimized_model_init = optimizer.optimize(model_def, passes)
onnx.save(optimized_model_init, 'onnx/split_init.onnx')

passes = ['split_predict', ]

optimized_model_pred = optimizer.optimize(model_def, passes)
onnx.save(optimized_model_pred, 'onnx/split_pred.onnx')

最適化後の初期化部分がこっち。定数同士の加算部分だけを別グラフにし、add0_outという名前で出力しています。初期値扱いで、一度計算すれば再計算は不要です。

split_init.onnx.png

推論の計算部分がこっち。上のグラフで計算したadd0_outの結果を使って、全体の計算を行っています。再計算が不要な部分は、別グラフに分離してあるので、定数などは気にせず全部の計算を行えばOKです。

split_pred.onnx.png

上手く分離されてますね。面白い。

全ソースです。

import numpy as np
import onnx
from onnx import helper, numpy_helper
from onnx import TensorProto
from onnx import optimizer

X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1])

A = np.ones((1,), dtype=np.float32)
B = np.ones((1,), dtype=np.float32)

const0 = onnx.helper.make_node(
    'Constant',
    inputs=[],
    outputs=['const_0'],
    value=onnx.helper.make_tensor(
        name='const_tensor',
        data_type=onnx.TensorProto.FLOAT,
        dims=A.shape,
        vals=A.flatten().astype(float),
    ),
)

const1 = onnx.helper.make_node(
    'Constant',
    inputs=[],
    outputs=['const_1'],
    value=onnx.helper.make_tensor(
        name='const_tensor',
        data_type=onnx.TensorProto.FLOAT,
        dims=B.shape,
        vals=B.flatten().astype(float),
    ),
)


add0 = helper.make_node(
    'Add',
    inputs = ['const_0', "const_1"],
    outputs = ['add0_out'],
)

add1 = helper.make_node(
    'Add',
    inputs=['X', 'add0_out'],
    outputs=['Y']
)

gemm = helper.make_node(
    'Gemm',
    ['X'],
    ['Y'],
    transA=1
)

graph_def = helper.make_graph(
    [const0, const1, add0, add1],
    'test-model',
    [X],
    [Y],
    initializer=[]

)

model_def = helper.make_model(
    graph_def,
    producer_name='onnx_example'
)

onnx.save(model_def, 'onnx/split.onnx')

onnx.checker.check_model(model_def)

# 最適化パスを指定
passes = ['split_init', ]

optimized_model_init = optimizer.optimize(model_def, passes)
onnx.save(optimized_model_init, 'onnx/split_init.onnx')

passes = ['split_predict', ]

optimized_model_pred = optimizer.optimize(model_def, passes)
onnx.save(optimized_model_pred, 'onnx/split_pred.onnx')
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0