ONNXの勉強を始めました。
まずはこちらのドキュメントを見ながらONNXファイルを作ってみました。
https://github.com/onnx/onnx/blob/master/docs/PythonAPIOverview.md
import onnx
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProto
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 2])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 4])
node_def = helper.make_node(
'Pad',
['X'],
['Y'],
mode = 'constant',
value = 1.5,
pads = [0, 1, 0, 1]
)
graph_def = helper.make_graph(
[node_def],
'test-model',
[X],
[Y]
)
model_def = helper.make_model(
graph_def,
producer_name='onnx_example'
)
print('The model is:\n{}'.format(model_def))
onnx.checker.check_model(model_def)
print('The model is checked!')
onnx.save(model_def, 'first_onnx.onnx')
出来上がったモデルはnetronで確認できます。
何をやっているのかさっぱりわからないので、少しだけソース追ってみました。
make_tensor_value_info
https://github.com/onnx/onnx/blob/master/onnx/helper.py
データの型と形からValueInfoProtoを作る関数。
引数は5つ。
- name ValueInfoProtoの名前
- elem_type データの型
- shape テンソルの形
- doc_string
- shape_denotation
shape_denotationは、shapeのチェックに使う引数で
len(shape_denotation) != len(shape)
このチェックに失敗すると、生成時にエラーを返してくれる。省略するかshape_denotation=Noneでチェックを行わない。
関数の中ではValueInfoProtoをインスタンス化して、引数を設定しているだけのよう。
shapeにNoneを指定するとNOPと同じになるが、protobufの表現と少し変わるのでそのままにしておいた方が良い。ふむ。
make_tensor_value_infoの中身。
def make_tensor_value_info(
name, # type: Text
elem_type, # type: int
shape, # type: Optional[Sequence[Union[Text, int]]]
doc_string="", # type: Text
shape_denotation=None, # type: Optional[List[Text]]
): # type: (...) -> ValueInfoProto
"""Makes a ValueInfoProto based on the data type and shape."""
value_info_proto = ValueInfoProto()
value_info_proto.name = name
if doc_string:
value_info_proto.doc_string = doc_string
tensor_type_proto = value_info_proto.type.tensor_type
tensor_type_proto.elem_type = elem_type
tensor_shape_proto = tensor_type_proto.shape
if shape is not None:
# You might think this is a no-op (extending a normal Python
# list by [] certainly is), but protobuf lists work a little
# differently; if a field is never set, it is omitted from the
# resulting protobuf; a list that is explicitly set to be
# empty will get an (empty) entry in the protobuf. This
# difference is visible to our consumers, so make sure we emit
# an empty shape!
tensor_shape_proto.dim.extend([])
if shape_denotation:
if len(shape_denotation) != len(shape):
raise ValueError(
'Invalid shape_denotation. '
'Must be of the same length as shape.')
for i, d in enumerate(shape):
dim = tensor_shape_proto.dim.add()
if d is None:
pass
elif isinstance(d, integer_types):
dim.dim_value = d
elif isinstance(d, text_type):
dim.dim_param = d
else:
raise ValueError(
'Invalid item in shape: {}. '
'Needs to of integer_types or text_type.'.format(d))
if shape_denotation:
dim.denotation = shape_denotation[i]
return value_info_proto
elem_type
elem_typeはどうやらprotbufで定義されている感じ。
message TensorProto {
enum DataType {
UNDEFINED = 0;
// Basic types.
FLOAT = 1; // float
UINT8 = 2; // uint8_t
INT8 = 3; // int8_t
UINT16 = 4; // uint16_t
INT16 = 5; // int16_t
INT32 = 6; // int32_t
INT64 = 7; // int64_t
STRING = 8; // string
BOOL = 9; // bool
// IEEE754 half-precision floating-point format (16 bits wide).
// This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
FLOAT16 = 10;
DOUBLE = 11;
UINT32 = 12;
UINT64 = 13;
COMPLEX64 = 14; // complex with float32 real and imaginary components
COMPLEX128 = 15; // complex with float64 real and imaginary components
// Non-IEEE floating-point format based on IEEE754 single-precision
// floating-point number truncated to 16 bits.
// This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
BFLOAT16 = 16;
// Future extensions go here.
}
複素数COMPLEX64は定義されているのに量子化は定義されていない気がする。
ドキュメントを見つけたので後で読む。
https://github.com/onnx/onnx/wiki/Quantization-Support-In-ONNX
make_node
NodeProtoを返す関数。
引数は7つだが、kwargsがあるので実際の数ははopに依存する。
- op_type operatorの名前
- inputs 入力の文字列をリストで指定
- outputs 出力の文字列をリストで指定
- name NodeProtoの名前
- doc_string
- domain ドメイン?
- **kwargs キーワード引数
domainは良く分からないがデフォルトのままで良さそう。
op_typeに使えるoperatorの一覧はここ
https://github.com/onnx/onnx/blob/master/docs/Operators.md
https://github.com/onnx/onnx/blob/master/docs/Operators-ml.md
新しいoperatorの追加方法はここ
https://github.com/onnx/onnx/blob/master/docs/AddNewOp.md
上で使った"Pad"は、Paddingなのでここにパラメータの意味が記載されています。
https://github.com/onnx/onnx/blob/master/docs/Operators.md#Pad
make_protoの中身。
def make_node(
op_type, # type: Text
inputs, # type: Sequence[Text]
outputs, # type: Sequence[Text]
name=None, # type: Optional[Text]
doc_string=None, # type: Optional[Text]
domain=None, # type: Optional[Text]
**kwargs # type: Any
): # type: (...) -> NodeProto
"""Construct a NodeProto.
Arguments:
op_type (string): The name of the operator to construct
inputs (list of string): list of input names
outputs (list of string): list of output names
name (string, default None): optional unique identifier for NodeProto
doc_string (string, default None): optional documentation string for NodeProto
domain (string, default None): optional domain for NodeProto.
If it's None, we will just use default domain (which is empty)
**kwargs (dict): the attributes of the node. The acceptable values
are documented in :func:`make_attribute`.
"""
node = NodeProto()
node.op_type = op_type
node.input.extend(inputs)
node.output.extend(outputs)
if name:
node.name = name
if doc_string:
node.doc_string = doc_string
if domain is not None:
node.domain = domain
if kwargs:
node.attribute.extend(
make_attribute(key, value)
for key, value in sorted(kwargs.items()))
return node
make_graph
GraphProtoを返す関数。
引数の説明
- nodes NodeProtoをリストで指定
- name グラフの名前
- inputs 入力に使うValueInfoProtoをリストで指定
- output 出力に使うValueInfoProtoをリストで指定
- initializer 初期値をTensorProtoのリストで指定
- doc_string
- value_info ValueInfoProtoをリストで指定
中身はGraphProtoを生成して値を設定しているだけ。
def make_graph(
nodes, # type: Sequence[NodeProto]
name, # type: Text
inputs, # type: Sequence[ValueInfoProto]
outputs, # type: Sequence[ValueInfoProto]
initializer=None, # type: Optional[Sequence[TensorProto]]
doc_string=None, # type: Optional[Text]
value_info=[], # type: Sequence[ValueInfoProto]
): # type: (...) -> GraphProto
if initializer is None:
initializer = []
if value_info is None:
value_info = []
graph = GraphProto()
graph.node.extend(nodes)
graph.name = name
graph.input.extend(inputs)
graph.output.extend(outputs)
graph.initializer.extend(initializer)
graph.value_info.extend(value_info)
if doc_string:
graph.doc_string = doc_string
return graph
make_model
ModelProtoを返す関数。
model.graph.CopyFrom(graph)が変換の本体っぽい。その後、キーワード引数にopset_importsがあるときは特別な処理をして、それ以外はsetattrをしているだけ。
def make_model(graph, **kwargs): # type: (GraphProto, **Any) -> ModelProto
model = ModelProto()
# Touch model.ir_version so it is stored as the version from which it is
# generated.
model.ir_version = IR_VERSION
model.graph.CopyFrom(graph)
opset_imports = None # type: Optional[Sequence[OperatorSetIdProto]]
opset_imports = kwargs.pop('opset_imports', None) # type: ignore
if opset_imports is not None:
model.opset_import.extend(opset_imports)
else:
# Default import
imp = model.opset_import.add()
imp.version = defs.onnx_opset_version()
for k, v in kwargs.items():
# TODO: Does this work with repeated fields?
setattr(model, k, v)
return model