LoginSignup
28
28

More than 5 years have passed since last update.

本記事は、TensorFlow Advent Calendar 2018 の9日目の記事です。
過去に投稿した記事ですが、Advent Calenderに投稿するにあたって記事を更新しました。

本記事では、ONNX形式で作成されたモデルを、TensorFlowをバックエンドとして実行する方法について説明します。

ONNX

ONNXに関しては、別の記事「ONNX形式のモデルを扱う」でまとめていますので、よろしければ参考にしてみてください。

ONNX-TensorFlow

ONNX形式のモデルをTensorFlowで実行するためには、ONNX形式のモデルをTensorFlowのグラフ(TFグラフ)に変換するためのコンバータ ONNX-TensorFlow が必要になります。
ONNX-TensorFlowはONNXコミュニティで開発され、GitHub上で公開されています。

onnx/onnx-tensorflow (GitHub)

ONNXコミュニティ では、TensorFlowに関係するプロジェクトとして onnx/tensorflow-onnx も公開されていますが、こちらはTensorFlowで作成したモデルをONNX形式のモデルへ変換するものです。将来的にこれらは統合される可能性がある とのことですが、現状2つに分かれているため注意が必要です。

※ ONNX-TensorFlowは現在も開発中であり、ONNXで提供されるすべてのOperationに対応できていません。対応しているOperation一覧は、ドキュメントから確認することができます。

ONNX-TensorFlowは、以下の一連のコマンドでインストールすることができます。

pip install tensorflow
pip install onnx-tf

前準備

ONNX形式のモデルを取得する

ONNX形式のモデルは、onnx/models に置かれています。今回は、VGG19モデルrelease 1.1 版を使ってみます。
モデルが置かれているリンクがあるので、ダウンロードします。ファイルサイズは500MB程度です。
ダウンロードが完了したらファイルを解凍し、本記事で作成するPythonのスクリプトが配置されたディレクトリに models ディレクトリを作成し、解凍したファイルを配置します。

推論する画像を取得する

本記事で利用する、推論対象の画像を取得します。推論対象の画像を探していたところ、onnx/tutorials にいくつかテスト用の画像が置かれていたので、猫の画像 を使うことにしました。
画像データをダウンロードし、本記事で作成するPythonのスクリプトが配置されたディレクトリに data ディレクトリを作成し、画像データを配置します。

ImageNetのクラスIDとラベル名が紐づいたデータを用意する

取得したVGG19のモデルは、訓練時にデータセットとしてImageNetの ILSVRC2014 を利用しています。
TensorFlowは、ImageNetで定義された1000個のラベルに対してその予測が最大となる、ILSVRC204のクラスIDしか出力しません。このためクラスIDから、そのクラスIDが示すラベル名を確認できるようにする必要があります。
幸いなことに、ImageNetで定義された1000個のラベルと英語でのラベル名を紐づけたリスト が有志によって作られています。
後でサンプルの実行結果とこのリストを照らし合わせ、推論した結果を確認します。

ONNX形式のモデルをTensorFlowで実行する

ONNX形式のモデルをTensorFlowで実行するためには、以下の手順を踏む必要があります。

  1. モデルに対する入力画像データの読み込みと加工
  2. ONNX形式のモデルの読み込み
  3. ONNX-TensorFlowを使い、ONNX形式のモデルをTensorFlowで実行
  4. 結果出力

上記の手順に従って、ONNX形式のモデルをTensorFlowで実行するサンプルを次に示します。
なお、本サンプルを実行するためにはPILパッケージのインストールが必要です。以下のコマンドを実行し、PILパッケージをインストールしておきます。

pip install pillow

サンプル

import onnx_tf.backend
import onnx

import numpy as np
from PIL import Image

img_path = "data/cat.jpg"
model_path = "models/vgg19/model.onnx"


def main():
    # 画像の読み込みと加工
    img = Image.open(img_path)
    img = img.resize((224, 224))
    arr = np.asarray(img, dtype=np.float32)[np.newaxis, :, :, :]
    arr = arr.transpose(0, 3, 1, 2)

    # ONNX形式のモデル読み込み
    onnx_model = onnx.load(model_path)

    # TensorFlowでONNX形式のモデルを実行
    tf_model = onnx_tf.backend.prepare(onnx_model, device='CPU')
    result = tf_model.run(arr)

    # 確率が高い順にクラスIDを昇順で出力
    prob = np.argsort(result.prob_1[0])[::-1]
    print("===== [Prob] =====")
    print(prob)

    # 確率が上位5個のクラスIDとその確率を表示する
    print("===== [TOP 5] =====")
    for i in range(5):
        print("{}: {}%".format(prob[i], result.prob_1[0][prob[i]] * 100))


if __name__ == "__main__":
    main()

解説

1. モデルに対する入力画像データの読み込みと加工

モデルの入出力データの形式は、それぞれ実行するモデルによって決められています。
今回使うモデルはVGG19で、入出力データ形式は onnx/models に書かれています。画像を読み込んだ後、ここに書かれている入力データの形式に合わせて画像データを加工する必要があります。
以下は、画像データの読み込みと加工を行うソースコードです。

    # 画像の読み込みと加工
    img = Image.open(img_path)
    img = img.resize((224, 224))
    arr = np.asarray(img, dtype=np.float32)[np.newaxis, :, :, :]
    arr = arr.transpose(0, 3, 1, 2)

本サンプルでは、PILを使って画像データを読み込みます。
PILは画像を扱う時に便利なライブラリですが、読み込んだ画像データは、numpyの配列で (width, height, color) の形式となっているため、TensorFlowではそのまま扱うことができません。このため、TensorFlowで扱える形式 (batch, color, width, height) に変換する必要があります。

2. ONNX形式のモデルの読み込み

ONNX形式のモデルの読み込みについては、こちらの記事 を参照してください。ここでは、説明を省略します。

3. ONNX-TensorFlowを使い、ONNX形式のモデルをTensorFlowで実行

ONNX-TensorFlowを使って、読み込んだONNX形式のモデルを実行するソースコードを以下に示します。ONNX-TensorFlowを使うためには、onnx_tfパッケージをimportする必要があります。

    # TensorFlowでONNX形式のモデルを実行
    tf_model = onnx_tf.backend.prepare(onnx_model, device='CPU')
    result = tf_model.run(arr)

本サンプルではCPUで実行するため、onnx_tf.backend.prepare の引数 device'CPU' を指定します。

4. 結果出力

本サンプルの実行結果は、以下の2つを出力します。

  • 確率が高い順に並べたクラスIDの配列
  • 確率が高い上位5個のクラスIDとその確率

結果

サンプルの実行結果を以下に示します。

===== [Prob] =====
[285 287 281 284 280 282 283 279 278 539 968 277 852 463  17 761 611 904
 383 722 700 987 572 434 876 412 384 855 356 470 861 725 778 756 659 738

...<snip>...

  50 339 449 142 436 354 612 874 547 705]
===== [TOP 5] =====
285: 43.46124231815338%
287: 17.308275401592255%
281: 14.673717319965363%
284: 5.337991192936897%
280: 4.124986007809639%

上記の結果をもとに、ImageNetで定義された1000個のラベルと英語でのラベル名を紐づけたリスト を使い、クラスIDとラベル名を上位5個について対応付けてみます。

クラスID ラベル名 確率
285 Egyptian cat 43.5%
287 lynx, catamount 17.3%
281 tabby, tabby cat 14.7%
284 Siamese cat, Siamese 5.3%
280 grey fox, gray fox, Urocyon cinereoargenteus 4.1%

このモデルは、約4割の自信をもって入力した画像が「Egyptian cat」であると推測しています。
また、上位5個を見ると約8割の自信をもって猫であると推測しています。
入力した画像が猫画像であることから、このモデルは入力画像を猫であると認識できていることがわかります。

※ 参考:Egyptian catのラベルで対応付けられた訓練用の画像一覧

おわりに

公開されているONNX形式のモデルVGG19をTensorFlowをバックエンドとして実行し、その結果を確かめてみました。
本サンプルは全体で20行程度であり、ONNX形式のモデルをTensorFlowで実行するところに限れば、コード量としてはたったの3行だけです。ONNX-TensorFlowはまだ開発段階であるため、全てのONNX形式のモデルを実行できるとは限りませんが、非常に簡単な手順でONNX形式のモデルをTensorFlowで実行できることがわかると思います。

28
28
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
28
28