ONNXがサポートしている最適化eliminate_deadend を調べてみました。
字面から接続先がない何か(deadend)を消してくれそうです。
ソースで検索するとここがひっかります。
https://github.com/onnx/onnx/blob/master/onnx/optimizer/passes/eliminate_deadend.h
unsigned int EliminateDead(Graph& graph) {
unsigned int nodes_removed = 0;
auto nodes = graph.nodes().reverse();
for (auto it = nodes.begin(); it != nodes.end(); it++) {
auto node = *it;
if (!node->hasUses()) {
nodes_removed++;
it.destroyCurrent();
}
}
return nodes_removed;
}
ソースを見ると、グラフのノードを逆順にして、使っていないnodeを削除しているようです。大事なのはノード単位の最適化らしいということ。
実際にやってみましょう。
こういうONNXを作ります。
Yはグラフの出力に接続されていますが、Reluの出力はどこにもつながっていません。
graph_def = helper.make_graph(
[node_def, relu0_op],
'test-model',
[X],
[Y]
)
ONNXで最適化を行うためには、最適化のpassを指定してoptimizer.optimizeを呼び出します。今回試したいのは、eliminate_deadendなので、このようにしました。
# 最適化パスを指定
passes = ['eliminate_deadend']
optimized_model = optimizer.optimize(model_def, passes)
onnx.save(optimized_model, 'onnx/eliminate_deadend_optimized.onnx')
実行してみると、見事にReluが削除されました。やったね。
全ソースはこちら
import onnx
from onnx import helper
from onnx import TensorProto
from onnx import optimizer
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 2])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 4])
node_def = helper.make_node(
'Pad',
['X'],
['Y'],
mode = 'constant',
value = 1.5,
pads = [0, 1, 0, 1]
)
relu0_op = helper.make_node(
'Relu',
inputs = ['Y'],
outputs = ['relu0_out']
)
graph_def = helper.make_graph(
[node_def, relu0_op],
'test-model',
[X],
[Y]
)
model_def = helper.make_model(
graph_def,
producer_name='onnx_example'
)
onnx.save(model_def, 'onnx/eliminate_deadend.onnx')
# 最適化パスを指定
passes = ['eliminate_deadend']
optimized_model = optimizer.optimize(model_def, passes)
onnx.save(optimized_model, 'onnx/eliminate_deadend_optimized.onnx')