More than 3 years have passed since last update.

fuse_consecutive_log_softmax optimization in ONNX


I looked into fuse_consecutive_log_softmax, one of the optimizations supported by ONNX.

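The optimization is implemented in the ONNX optimizer's fuse_consecutive_log_softmax pass. It only takes effect when the nodes appear in the order Softmax -> Log, not Log -> Softmax.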
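Why fuse the pair at all? Besides saving a node, a fused LogSoftmax can compute x - max(x) - log(sum(exp(x - max(x)))) directly, while a separate Log applied to a Softmax result hits log(0) as soon as a probability underflows to 0. The plain-Python sketch below illustrates that numerical difference; the function names are hypothetical and it is not how ONNX or any runtime actually implements these operators.

```python
import math

def softmax(xs):
    # Subtract the max before exponentiating to avoid overflow,
    # then normalize so the values sum to 1.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def log_of_softmax(xs):
    # Unfused pattern: Softmax followed by a separate Log.
    # A probability that underflowed to 0.0 can only map to -inf.
    return [math.log(p) if p > 0.0 else float('-inf') for p in softmax(xs)]

def log_softmax(xs):
    # Fused pattern: x - max - log(sum(exp(x - max))), no log of a tiny probability.
    m = max(xs)
    lse = math.log(sum(math.exp(x - m) for x in xs))
    return [x - m - lse for x in xs]

print(log_of_softmax([0.0, 1000.0]))  # first element is -inf
print(log_softmax([0.0, 1000.0]))     # finite: [-1000.0, 0.0]
```

For well-scaled inputs such as [1.0, 2.0, 3.0, 4.0] the two functions agree up to rounding; they only diverge once exp underflows.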

Graph before optimization

fuse_consecutive_log_softmax.onnx.png

Specify fuse_consecutive_log_softmax in the list of passes and call optimizer.optimize.

passes = ['fuse_consecutive_log_softmax']

optimized_model = optimizer.optimize(model_def, passes)

Graph after optimization

fuse_consecutive_log_softmax_optimized.onnx.png

The Softmax -> Log pair has been replaced by a single LogSoftmax node, as intended.
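To make the rewrite concrete, here is a toy sketch of the graph transformation on a hypothetical list-of-nodes representation (the `Node` class and `fuse_consecutive_log_softmax` function are illustrative, not the ONNX optimizer's C++ code). It fuses a Log into the Softmax that feeds it only when nothing else consumes the Softmax output, which is the condition that keeps the rewrite safe; the LogSoftmax node keeps Softmax's input and attributes (such as `axis`) and takes over Log's output.

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    # Hypothetical minimal stand-in for an ONNX NodeProto.
    op_type: str
    inputs: list
    outputs: list
    attrs: dict = field(default_factory=dict)

def fuse_consecutive_log_softmax(nodes):
    """Replace each Softmax -> Log pair with one LogSoftmax node (toy sketch)."""
    # Map each tensor name to the nodes that consume it, and to its producer.
    consumers = {}
    for n in nodes:
        for name in n.inputs:
            consumers.setdefault(name, []).append(n)
    producer = {out: n for n in nodes for out in n.outputs}

    fused = {}       # id(Softmax node) -> replacement LogSoftmax node
    dropped = set()  # ids of Log nodes absorbed into a LogSoftmax
    for n in nodes:
        if n.op_type != 'Log':
            continue
        src = producer.get(n.inputs[0])  # None for graph inputs
        # Fuse only if Log reads a Softmax whose output has no other consumer.
        if src is not None and src.op_type == 'Softmax' and len(consumers[n.inputs[0]]) == 1:
            fused[id(src)] = Node('LogSoftmax', list(src.inputs), list(n.outputs), dict(src.attrs))
            dropped.add(id(n))

    # Put each LogSoftmax where its Softmax was and drop the absorbed Log.
    return [fused.get(id(n), n) for n in nodes if id(n) not in dropped]

graph = [
    Node('Softmax', ['X'], ['softmax_out'], {'axis': 1}),
    Node('Log', ['softmax_out'], ['Y']),
]
print([(n.op_type, n.inputs, n.outputs, n.attrs) for n in fuse_consecutive_log_softmax(graph)])
# [('LogSoftmax', ['X'], ['Y'], {'axis': 1})]
```

Running it on a Log -> Softmax graph leaves both nodes untouched, matching the observation that only the Softmax -> Log order is optimized.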

Full source:


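import os

import onnx
from onnx import helper
from onnx import TensorProto

# onnx.optimizer was split out of newer onnx releases into the separate
# onnxoptimizer package; fall back to it if the old module is missing.
try:
    from onnx import optimizer
except ImportError:
    import onnxoptimizer as optimizer

# Graph input X and output Y, both float tensors of shape [1, 4]
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 4])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 4])

# Softmax over axis 1 followed by Log: the pattern the pass fuses
softmax = helper.make_node(
    'Softmax',
    inputs=['X'],
    outputs=['softmax_out'],
    axis=1,
)

log = helper.make_node(
    'Log',
    inputs=['softmax_out'],
    outputs=['Y'],
)

graph_def = helper.make_graph(
    [softmax, log],
    'test-model',
    [X],
    [Y],
)

model_def = helper.make_model(
    graph_def,
    producer_name='onnx_example',
)

# Validate the model, then save it (creating the output directory if needed)
onnx.checker.check_model(model_def)
os.makedirs('onnx', exist_ok=True)
onnx.save(model_def, 'onnx/fuse_consecutive_log_softmax.onnx')

# Specify the optimization pass
passes = ['fuse_consecutive_log_softmax']

optimized_model = optimizer.optimize(model_def, passes)
onnx.save(optimized_model, 'onnx/fuse_consecutive_log_softmax_optimized.onnx')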
