ONNXでモデルを表現しよう
ONNX(Open Neural Network Exchange) は機械学習やディープラーニングなどのAIモデルを表現するためのオープンソースのフォーマットです。プラットフォームに依存しない、という特徴があり、ONNXモデルを経由することによりPyTorchで学習したモデルをTensorflowで推論することや、ワークステーション上のPythonスクリプトで学習したモデルをエッジデバイスのC++アプリで推論することなどが可能になります。
Raspberry Pi 5のAI KitであるHailo-8がONNXに対応しているとのことなので、ゴールはUbuntu24.04 + RTX3060Tiでモデルを学習させて、Raspberry Pi 5によりエッジで推論する構成を目指しているのですが、本記事では、まずMacBook Pro上でONNXを味見してみようということを目的に、遊んでみました。
関連記事
Netron
なお、本記事ではネットワークが変換できているかを確認するために、Netronを利用します。Netronはモデルを可視化できるツールで、インストールすることなくWebから利用可能です。Netronを使うことにより深層学習の各層の結合を確認できる他、各層の入力テンソルも確認することができます。下記の例は入力テンソルが (1, 3, 224, 224)
となっているResNetの例です。
学習済みのResNetをONNXモデルとして扱う
PyTorchのモデルをONNXモデルへ書き出す
まずPyTorchのtorchvision
にバンドルされている学習済みのresnet101
をONNXモデルに変換してみましょう。resnet101
は入力に1枚の画像、RGBチャネル、224x224ピクセルの画像をとり、出力に1000種類のクラスの確率をとるモデルです。
resnet101
をONNX形式へ変換するmodel_converter.py
の内容は以下となります。モデルをtorchvision.models
から読み込みresnet
変数に格納した後、推論モードへ切り替え、ダミーの入力データ(1, 3, 224, 224)
を準備し、onnx.export
を呼び出すことでONNXモデルのファイルを生成できます。
# 実行のベースとなるtorchをインポート
import torch
# torchvisionからモデルと画像変換機能をインポート
from torchvision import models
# PyTorchのONNX拡張
import torch.onnx as onnx
##############
# resnetの準備
##############
print(dir(models))
# モデルに含まれているresnetを、学習済みの状態で取得
resnet = models.resnet101(pretrained=True)
# resnetのネットワーク構造を表示
print(resnet)
# resnetを推論モードに切り替え
resnet.eval()
# ダミーデータを作成する
input_image = torch.zeros((1, 3, 224, 224))
# ONNX形式で書き出す
onnx.export(resnet, input_image, "model.onnx")
# 終了する
quit()
PyTorchのモデルからONNXモデルへの変換手順は以下の通りです。変換にはonnx
とonnxruntime
をインストールする必要があるため、実行前に仮想環境にpip
コマンドでインストールします。
# 作業ディレクトリを作成する
$ mkdir pytorchbox
$ cd pytorchbox
# 仮想環境を作成する
$ conda create --prefix ./env python=3.8
# 仮想環境を有効化する
$ conda activate ./env
# PyTorchをインストールする
(env) $ conda install pytorch::pytorch torchvision torchaudio -c pytorch
# onnxとonnxランタイムをインストールする
(env) $ pip install onnx
(env) $ pip install onnxruntime
# PyTorchモデルをONNXへ変換するプログラム
(env) $ ls *.py
model_converter.py
# 変換を実行する
(env) $ python model_converter.py
...
# ONNXモデルができていることを確認する
(env) $ ls *.onnx
model.onnx
生成した model.onnx
をNetronから開くことで、下記のようにネットワークを可視化可能です。入力テンソルが(1, 3, 224, 224)
で、出力テンソルが(1, 1000)
となっていることを確認することができますね。
ONNXモデルをPythonで読み込んで推論する
続いて、ONNXモデルをPythonで読み込んで推論をしてみましょう。基本的な構造は下記記事と同等で、変更されている箇所は、推論の呼び出し方や後処理周りです。
利用したパッケージ
ONNXモデルの実行にPyTorchは必要ありませんが、今回はカメラからの入力画像やテンソルを加工しやすいことからtorch
、torchvision
、PIL
をインポートしています。また、ONNXモデルを扱うためonnx
とonnxruntime
をインポートしています。
全体の流れ
ONNXモデルはort.InferenceSession
により読み込みます。モデルの入力層と出力層の名前を取得して、推論を実行する際(ort_session.run
)に指定します。推論の入力となるWEBカメラで撮影した画像をPyTorchにより(1, 3, 224, 224)
のテンソルに変換して.numpy().copy()
でnumpy配列とします。これをort_session.run
の引数に与えて推論を実行します。推論結果は1000クラスそれぞれの確率となりますので、確率を昇順に並べ、カメラに何が映ったかを判別します。今回は、出力層をNetronを見ると最後のGemm層で(x2048)
していましたので、2048で割ってパーセンテージに変換し、50%以上であれば出力するようにしました。
# 実行のベースとなるtorchをインポート
import torch
# torchvisionからモデルと画像変換機能をインポート
from torchvision import transforms
# 画像を扱うライブラリPILをインポート
from PIL import Image
# ONNXを制御するパッケージをインポート
import onnx
import onnxruntime as ort
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions
# 数値処理用のパッケージをインポート
import numpy as np
# Webカメラを制御するOpenCVをインポート
import cv2
# /dev/video0を指定
DEV_ID = 0
# 撮影するサイズ
WIDTH = 640
HEIGHT = 480
# 認識の閾値
THRESHOLD = 50
##############
# resnetの準備
##############
# ONNX実行時のオプションを設定する
options = SessionOptions()
options.intra_op_num_threads = 1
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
# ONNXモデルを読み込む
ort_session = ort.InferenceSession('model.onnx')
ort_session.disable_fallback()
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name
# 入力層の名前を取得
print("input:")
for session_input in ort_session.get_inputs():
print(session_input.name, session_input.shape)
# 出力層の名前を取得
print("output:")
for session_output in ort_session.get_outputs():
print(session_output.name, session_output.shape)
# resnetの分類(クラス)をlabelsに読み込む
with open('./imagenet_classes.txt') as f:
labels = [line.strip() for line in f.readlines()]
# OpenCVで撮影した画像をresnetの入力テンソルに変換する
preprocess = transforms.Compose([
transforms.ToPILImage(), # numpy.ndarray -> pillow
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
##############
# Webカメラに接続する
##############
cap = cv2.VideoCapture(DEV_ID)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, WIDTH)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, HEIGHT)
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
##############
# OpenCVで撮影し、resnetテンソルに変換
##############
while(True):
try:
# 1フレーム撮影する
ret, frame = cap.read()
if ret:
# 撮影したフレームを表示
cv2.imshow('frame', frame)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
# 画像をBGRからRGBの並びに変換
input_image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# 画像をresnetの入力テンソルに変換する
input_tensor = preprocess(input_image)
# 画像をバッチに内包する(1次元追加)
batch_t = torch.unsqueeze(input_tensor, dim=0)
# ONNXに渡すデータへと変換する
input_onnx = batch_t.numpy().copy()
##############
# 入力テンソルをresnetで推論、結果を表示する
##############
# カメラの画像をresnetで推論
# runの第一引数:出力層の名前
# runの第二引数:入力narray
out_onnx = ort_session.run([output_name], {input_name: input_onnx})
#out = torch.from_numpy(out_onnx[0][0]).clone()
#_, index = torch.max(out, 0)
out = out_onnx[0][0]
# 確率を昇順にソートする
sorted = np.sort(out)[::-1]
sorted_idx = np.argsort(out)[::-1]
# パーセンテージに変換
parcentage = (sorted * 100 / 2048) * 100
# 先頭のアイテムを取り出す
main_percentage = parcentage[0]
main_label = labels[sorted_idx[0]]
#print('%s(%s)' % (main_label, "{:.1f}".format(main_percentage, 1)))
# 確率一定以上であれば結果を出力
if main_percentage > THRESHOLD:
print('%s(%s)' % (main_label, "{:.1f}".format(main_percentage, 1)))
except KeyboardInterrupt:
print('abort')
スクリプトの実行手順は以下の通りです。
# 作業ディレクトリへ移動
(base) $ cd pytorchbox
(base) $ pwd
/Users/shino/pytorchbox
# 実行にはスクリプトの他に
# クラス名のリスト(imagenet_classes.txt)と
# ONNXモデルファイル(model.onnx:前手順でエクスポートしたもの)が必要となる
(base) $ ls
fetch_model_and_detect.py
model.onnx
env
imagenet_classes.txt
model_converter.py
# 仮想環境を有効にする
(base) $ conda activate ./env
(env) $
# 実行
(env) $ python fetch_model_and_detect.py
# book jacket(58.6)
# book jacket(56.5)
# book jacket(60.6)
# ... ターミナルを閉じて終了
認識できましたね!これでまた一歩前進です!
参考させていただいた記事
今回も多数の記事を参考にさせていただきました。ありがとうございます。