LoginSignup
4
4

More than 3 years have passed since last update.

初めてのONNX

Posted at

ONNXの勉強を始めました。

まずはこちらのドキュメントを見ながらONNXファイルを作ってみました。
https://github.com/onnx/onnx/blob/master/docs/PythonAPIOverview.md

import onnx
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProto

X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 2])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 4])

node_def = helper.make_node(
    'Pad',
    ['X'],
    ['Y'],
    mode = 'constant',
    value = 1.5,
    pads = [0, 1, 0, 1]
)

graph_def = helper.make_graph(
    [node_def],
    'test-model',
    [X],
    [Y]
)

model_def = helper.make_model(
    graph_def,
    producer_name='onnx_example'
)

print('The model is:\n{}'.format(model_def))
onnx.checker.check_model(model_def)
print('The model is checked!')

onnx.save(model_def, 'first_onnx.onnx')

出来上がったモデルはnetronで確認できます。

first_onnx.png

何をやっているのかさっぱりわからないので、少しだけソース追ってみました。

make_tensor_value_info

https://github.com/onnx/onnx/blob/master/onnx/helper.py
データの型と形からValueInfoProtoを作る関数。

引数は5つ。

  • name ValueInfoProtoの名前
  • elem_type データの型
  • shape テンソルの形
  • doc_string
  • shape_denotation

shape_denotationは、shapeのチェックに使う引数で

len(shape_denotation) != len(shape)

このチェックに失敗すると、生成時にエラーを返してくれる。省略するかshape_denotation=Noneでチェックを行わない。

関数の中ではValueInfoProtoをインスタンス化して、引数を設定しているだけのよう。

shapeにNoneを指定するとNOPと同じになるが、protobufの表現と少し変わるのでそのままにしておいた方が良い。ふむ。

make_tensor_value_infoの中身。

def make_tensor_value_info(
        name,  # type: Text
        elem_type,  # type: int
        shape,  # type: Optional[Sequence[Union[Text, int]]]
        doc_string="",  # type: Text
        shape_denotation=None,  # type: Optional[List[Text]]
):  # type: (...) -> ValueInfoProto
    """Makes a ValueInfoProto based on the data type and shape."""
    value_info_proto = ValueInfoProto()
    value_info_proto.name = name
    if doc_string:
        value_info_proto.doc_string = doc_string

    tensor_type_proto = value_info_proto.type.tensor_type
    tensor_type_proto.elem_type = elem_type

    tensor_shape_proto = tensor_type_proto.shape

    if shape is not None:
        # You might think this is a no-op (extending a normal Python
        # list by [] certainly is), but protobuf lists work a little
        # differently; if a field is never set, it is omitted from the
        # resulting protobuf; a list that is explicitly set to be
        # empty will get an (empty) entry in the protobuf. This
        # difference is visible to our consumers, so make sure we emit
        # an empty shape!
        tensor_shape_proto.dim.extend([])

        if shape_denotation:
            if len(shape_denotation) != len(shape):
                raise ValueError(
                    'Invalid shape_denotation. '
                    'Must be of the same length as shape.')

        for i, d in enumerate(shape):
            dim = tensor_shape_proto.dim.add()
            if d is None:
                pass
            elif isinstance(d, integer_types):
                dim.dim_value = d
            elif isinstance(d, text_type):
                dim.dim_param = d
            else:
                raise ValueError(
                    'Invalid item in shape: {}. '
                    'Needs to of integer_types or text_type.'.format(d))

            if shape_denotation:
                dim.denotation = shape_denotation[i]

    return value_info_proto

elem_type

elem_typeはどうやらprotbufで定義されている感じ。

message TensorProto {
  enum DataType {
    UNDEFINED = 0;
    // Basic types.
    FLOAT = 1;   // float
    UINT8 = 2;   // uint8_t
    INT8 = 3;    // int8_t
    UINT16 = 4;  // uint16_t
    INT16 = 5;   // int16_t
    INT32 = 6;   // int32_t
    INT64 = 7;   // int64_t
    STRING = 8;  // string
    BOOL = 9;    // bool

    // IEEE754 half-precision floating-point format (16 bits wide).
    // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
    FLOAT16 = 10;

    DOUBLE = 11;
    UINT32 = 12;
    UINT64 = 13;
    COMPLEX64 = 14;     // complex with float32 real and imaginary components
    COMPLEX128 = 15;    // complex with float64 real and imaginary components

    // Non-IEEE floating-point format based on IEEE754 single-precision
    // floating-point number truncated to 16 bits.
    // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
    BFLOAT16 = 16;

    // Future extensions go here.
  }

複素数COMPLEX64は定義されているのに量子化は定義されていない気がする。
ドキュメントを見つけたので後で読む。
https://github.com/onnx/onnx/wiki/Quantization-Support-In-ONNX

make_node

NodeProtoを返す関数。

引数は7つだが、kwargsがあるので実際の数ははopに依存する。

  • op_type operatorの名前
  • inputs 入力の文字列をリストで指定
  • outputs 出力の文字列をリストで指定
  • name NodeProtoの名前
  • doc_string
  • domain ドメイン?
  • **kwargs キーワード引数

domainは良く分からないがデフォルトのままで良さそう。

op_typeに使えるoperatorの一覧はここ
https://github.com/onnx/onnx/blob/master/docs/Operators.md
https://github.com/onnx/onnx/blob/master/docs/Operators-ml.md

新しいoperatorの追加方法はここ
https://github.com/onnx/onnx/blob/master/docs/AddNewOp.md

上で使った"Pad"は、Paddingなのでここにパラメータの意味が記載されています。
https://github.com/onnx/onnx/blob/master/docs/Operators.md#Pad

make_protoの中身。

def make_node(
        op_type,  # type: Text
        inputs,  # type: Sequence[Text]
        outputs,  # type: Sequence[Text]
        name=None,  # type: Optional[Text]
        doc_string=None,  # type: Optional[Text]
        domain=None,  # type: Optional[Text]
        **kwargs  # type: Any
):  # type: (...) -> NodeProto
    """Construct a NodeProto.
    Arguments:
        op_type (string): The name of the operator to construct
        inputs (list of string): list of input names
        outputs (list of string): list of output names
        name (string, default None): optional unique identifier for NodeProto
        doc_string (string, default None): optional documentation string for NodeProto
        domain (string, default None): optional domain for NodeProto.
            If it's None, we will just use default domain (which is empty)
        **kwargs (dict): the attributes of the node.  The acceptable values
            are documented in :func:`make_attribute`.
    """

    node = NodeProto()
    node.op_type = op_type
    node.input.extend(inputs)
    node.output.extend(outputs)
    if name:
        node.name = name
    if doc_string:
        node.doc_string = doc_string
    if domain is not None:
        node.domain = domain
    if kwargs:
        node.attribute.extend(
            make_attribute(key, value)
            for key, value in sorted(kwargs.items()))
    return node

make_graph

GraphProtoを返す関数。

引数の説明

  • nodes NodeProtoをリストで指定
  • name グラフの名前
  • inputs 入力に使うValueInfoProtoをリストで指定
  • output 出力に使うValueInfoProtoをリストで指定
  • initializer 初期値をTensorProtoのリストで指定
  • doc_string
  • value_info ValueInfoProtoをリストで指定

中身はGraphProtoを生成して値を設定しているだけ。

def make_graph(
    nodes,  # type: Sequence[NodeProto]
    name,  # type: Text
    inputs,  # type: Sequence[ValueInfoProto]
    outputs,  # type: Sequence[ValueInfoProto]
    initializer=None,  # type: Optional[Sequence[TensorProto]]
    doc_string=None,  # type: Optional[Text]
    value_info=[],  # type: Sequence[ValueInfoProto]
):  # type: (...) -> GraphProto
    if initializer is None:
        initializer = []
    if value_info is None:
        value_info = []
    graph = GraphProto()
    graph.node.extend(nodes)
    graph.name = name
    graph.input.extend(inputs)
    graph.output.extend(outputs)
    graph.initializer.extend(initializer)
    graph.value_info.extend(value_info)
    if doc_string:
        graph.doc_string = doc_string
    return graph

make_model

ModelProtoを返す関数。
model.graph.CopyFrom(graph)が変換の本体っぽい。その後、キーワード引数にopset_importsがあるときは特別な処理をして、それ以外はsetattrをしているだけ。

def make_model(graph, **kwargs):  # type: (GraphProto, **Any) -> ModelProto
    model = ModelProto()
    # Touch model.ir_version so it is stored as the version from which it is
    # generated.
    model.ir_version = IR_VERSION
    model.graph.CopyFrom(graph)

    opset_imports = None  # type: Optional[Sequence[OperatorSetIdProto]]
    opset_imports = kwargs.pop('opset_imports', None)  # type: ignore
    if opset_imports is not None:
        model.opset_import.extend(opset_imports)
    else:
        # Default import
        imp = model.opset_import.add()
        imp.version = defs.onnx_opset_version()

    for k, v in kwargs.items():
        # TODO: Does this work with repeated fields?
        setattr(model, k, v)
    return model
4
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
4