91
84

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

ONNX形式のモデルを扱う

Last updated at Posted at 2018-04-18

本記事では、ONNX形式のモデルが登場した背景やそのImporter/Exporterのサポート状況、そしてONNX形式のモデルそのものをPythonから扱う方法について説明します。
ONNX形式のモデルを各Deep Learning用のフレームワークにImportして実行する方法については対象外のため、各フレームワークのチュートリアルを各自で参照してください。

背景

世の中には数多くのDeep Learning用のフレームワークが存在しますが、各フレームワークが扱うことのできるモデルはフレームワークによって異なります。このため、これまでは特定のフレームワークを使って作成した学習済みモデルは、同じフレームワークを使って推論する必要がありました。
例えば、TensorFlowで学習したモデルを使って推論する場合、学習に利用したTensorFlowを使って推論するという選択肢しかありませんでした。つまり、TensorFlowを使って学習したモデルをPyTorchで推論するということはこれまでできませんでした。様々な制約下に置かれるハードウェアで推論演算する需要が高まっている中で、学習と推論を同じフレームワークで行わなければならないのは、フレームワークを動かすために必要なソフトウェアのインストールが必要になるなど、ハードウェア側としてはうれしくありません。

ONNXとは

このような状況の中、ONNXと呼ばれるニューラルネットワークのモデルを定義するためのオープンフォーマットが登場しました。ONNXは当初、NNVM/TVMをはじめとした、推論向けのグラフコンパイラやDSLにおけるImporterのサポートから始まり、最近では各フレームワークのImporter/Exporterのサポートが進んでいます。また最近では Amazon、Microsoft、FaceBookが中心になってWGが立ち上がり、今後ますますONNXの普及が進んでいくと考えられます。

ONNXのサポート状況

以下に示すように、世の中の全てのフレームワークが、ONNXのImporter/Exporterをサポートしているわけではありません。しかし、ONNXが普及されるにつれ、Importer/Exporterが各フレームワークでサポートされるのは時間の問題かと思います。本記事執筆時点でもONNXのImporter/Exporterの開発は進んでいます。

フレームワーク Importer Exporter
Caffe × ×
TensorFlow(※1)
MXNet ×
CNTK
Chainer
PyTorch ×
Caffe2(※2)
○:サポート、△:一部サポート(Experimental)、×:未対応

※1 外部プロジェクト(ONNX Organization)でサポート
※2 2018/4月頭に、Caffe2のソースコードはPyTorchプロジェクトで管理されることになり、実質PyTorchの一機能としてCaffe2が提供されるようになりました

各フレームワークのONNX Importer/Exporterサポート状況

PythonからONNX形式のモデルを扱う

さて本題である、PythonからONNX形式のモデルを読み込む方法とONNX形式のモデルを作る方法を説明したいと思います。

環境構築

Anacondaのインストール

ONNXは、Anacondaのインストールが必要です。
Anacondaの公式ホームページ からAnacondaをインストールします。

ONNXのインストール

ONNXの公式ホームページ を参考に、ONNXのPythonモジュールとそれに依存するパッケージをインストールします。

ONNXモデルの読み込み

最初に、ONNX形式で作成されたモデルを読み込む方法を説明します。

ONNX形式のモデルの取得

読み込むONNX形式のモデルは、各フレームワークに備わっているExporter機能を利用して自分で作成してもよいですが、ここでは世の中に公開されているモデルを利用します。
ONNX形式のモデルは、GitHubプロジェクト onnx/models から取得することができます1
ここでは、上記プロジェクトで最も古くから公開されているモデルの1つである VGG19 を使います。

ONNX形式のモデルを読み込むプログラム

ONNX形式のモデルを読み込むPythonプログラム例を示します。このプログラムは、VGG19のONNX形式のモデルを読み込み、読み込んだモデル(グラフ)を構成するノードと入力データ、出力データの一覧を標準出力に出力します。

import onnx

model_path = "vgg19/model.onnx"


def main():
    # ONNX形式のモデルを読み込む
    model = onnx.load(model_path)

    # モデル(グラフ)を構成するノードを全て出力する
    print("====== Nodes ======")
    for i, node in enumerate(model.graph.node):
        print("[Node #{}]".format(i))
        print(node)

    # モデルの入力データ一覧を出力する
    print("====== Inputs ======")
    for i, input in enumerate(model.graph.input):
        print("[Input #{}]".format(i))
        print(input)

    # モデルの出力データ一覧を出力する
    print("====== Outputs ======")
    for i, output in enumerate(model.graph.output):
        print("[Output #{}]".format(i))
        print(output)


if __name__ == "__main__":
    main()

実行結果

上記のプログラムを実行した結果を以下に示します。結果を全てここに記載すると非常に長い記事になってしまいますので、ノードと入力データについては最初の要素のみ記載し、その他の要素は省略しました。
実行結果を見るとわかると思いますが、VGG19を構成する演算ノードや入出力に関する情報が出力されています。

====== Nodes ======
[Node #0]
input: "data_0"
input: "conv1_1_w_0"
input: "conv1_1_b_0"
output: "conv1_1_1"
name: ""
op_type: "Conv"
attribute {
  name: "kernel_shape"
  ints: 3
  ints: 3
  type: INTS
}
attribute {
  name: "strides"
  ints: 1
  ints: 1
  type: INTS
}
attribute {
  name: "pads"
  ints: 1
  ints: 1
  ints: 1
  ints: 1
  type: INTS
}

[Node #1]
...(略)...

====== Inputs ======
[Input #0]
name: "conv1_1_w_0"
type {
  tensor_type {
    elem_type: FLOAT
    shape {
      dim {
        dim_value: 64
      }
      dim {
        dim_value: 3
      }
      dim {
        dim_value: 3
      }
      dim {
        dim_value: 3
      }
    }
  }
}

[Input #1]
...(略)...

====== Outputs ======
[Output #0]
name: "prob_1"
type {
  tensor_type {
    elem_type: FLOAT
    shape {
      dim {
        dim_value: 1
      }
      dim {
        dim_value: 1000
      }
    }
  }
}

ONNX形式のモデル作成

続いて、PythonプログラムでONNXのモデルを作る方法を説明します。

作成するモデルの決定

ここでは、以下のようなグラフ構造を持ったモデルを作成します。

入力Tensor(Input1、Input2)を2つ受け取り、1つ(Input1)はReLU演算の入力とします。
もう一方の入力Tensor(Input2)は、ReLU演算の出力結果と合わせてAdd演算の入力とします。

PythonからONNX形式のモデルを作成

上記のグラフ構造を持つ、ONNX形式のモデルを作成するプログラム例を示します。

import onnx
import onnx.helper as oh
from onnx import checker

# モデルの出力ファイル名
out_path = "custom_model.onnx"


def main():
    # 入出力Tensor、および中間で使用するTensorを作成
    in_tensor = [
        oh.make_tensor_value_info("Input1", onnx.TensorProto.FLOAT, [3, 2]),
        oh.make_tensor_value_info("Input2", onnx.TensorProto.FLOAT, [3, 2]),
    ]
    im_tensor = [
        oh.make_tensor_value_info("Result ReLU", onnx.TensorProto.FLOAT, [3, 2]),
    ]
    out_tensor = [
        oh.make_tensor_value_info("Output", onnx.TensorProto.FLOAT, [3, 2]),
    ]

    # 計算ノードを作成
    # 本プログラムでは、ReLU演算とAdd演算のノードを作成
    nodes = []
    nodes.append(oh.make_node("Relu", ["Input1"], ["Result ReLU"]))
    nodes.append(oh.make_node("Add", ["Result ReLU", "Input2"], ["Output"]))

    # グラフを作成
    # 以下に示すグラフが作成される
    # [Input1] - Relu - Add - [Output]
    #                 /
    #         [Input2]
    graph = oh.make_graph(nodes, "Test Graph", in_tensor, out_tensor)

    # グラフが正しく作成できていることを確認
    checker.check_graph(graph)

    # モデルを構築
    model = oh.make_model(graph, producer_name="AtuNuka", producer_version="0.1")

    # モデルが正しく作成できていることを確認
    checker.check_model(model)

    # 作成したモデルをファイルへ保存(バイナリ形式)
    with open(out_path, "wb") as f:
        f.write(model.SerializeToString())

    # バイナリ形式では内容を確認できないため、テキスト形式でも保存
    with open(out_path + ".txt", "w") as f:
        print(model, file=f)


if __name__ == "__main__":
    main()

作成されたモデルの確認

プログラムを実行すると、プログラムを実行したディレクトリにバイナリファイル custom_model.onnx と、テキストファイル custom_model.onnx.txt が作成されます。
実際にONNX形式のモデルをDeployする時には、バイナリファイルを利用することになります。しかしここでは、可読性の高いテキストファイルも一緒に出力し、モデルがきちんと作成できていることを確認します。テキストファイルのフォーマットについてここでは説明しませんが、graphnodeinputoutput などのモデルを構築するための情報が記述され、ONNX形式のモデルが正しく作れていそうだと想像できます。

ir_version: 3
producer_name: "AtuNuka"
producer_version: "0.1"
graph {
  node {
    input: "Input1"
    output: "Result ReLU"
    op_type: "Relu"
  }
  node {
    input: "Result ReLU"
    input: "Input2"
    output: "Output"
    op_type: "Add"
  }
  name: "Test Graph"
  input {
    name: "Input1"
    type {
      tensor_type {
        elem_type: FLOAT
        shape {
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 2
          }
        }
      }
    }
  }
  input {
    name: "Input2"
    type {
      tensor_type {
        elem_type: FLOAT
        shape {
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 2
          }
        }
      }
    }
  }
  output {
    name: "Output"
    type {
      tensor_type {
        elem_type: FLOAT
        shape {
          dim {
            dim_value: 3
          }
          dim {
            dim_value: 2
          }
        }
      }
    }
  }
}
opset_import {
  version: 2
}

作成したモデルの可視化

先ほどのテキストファイルからモデルが構築できているであろうと予想できましたが、本当に期待したモデルになっているのか直感的にわかりにくいです。
そこで、Visualizing an ONNX Model を参考に、作ったモデルを可視化してみたいと思います。

以下のコマンドを実行し、ONNXで作成したモデルをSVG形式の画像ファイルに変換します。

 $ git clone https://github.com/onnx/onnx.git
 $ cd onnx
 $ python onnx/tools/net_drawer.py --input path/to/custom_model.onnx --output custom_model.dot --embed_docstring
 $ dot -Tsvg custom_model.dot -o custom_model.svg

上記のコマンドにより変換した画像を以下に示します。

中間データである Result ReLU が表示されていますが、期待したグラフが作成されていることがわかります。

おわりに

Pythonのonnxモジュールを利用することで、ONNX形式のモデルを読み込んだり作成したりできます。各フレームワークが提供するONNXのImporter/Exporterは、本記事で紹介したonnxモジュールを使って各フレームワークのグラフ構造をONNX形式のモデルに変換しています。
最近では、WGが立ち上がるなどONNXが普及しつつある一方、ONNXが比較的最近登場したモデル形式でもあり、全てのフレームワークがONNX形式のモデルに対応できていない状況です。ONNXのImporter/Exporterの開発が進み、将来的にはフレームワーク間の差異を気にすることなく演算ができるようになることを期待したいと思います。

参考情報

  1. 本記事の執筆時点では、著名な画像認識系のネットワークモデルしか配置されていませんが、今後標準化が進んでいくことで増えていくと予想されます

91
84
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
91
84

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?