1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

PyTorchの学習済みResNetをONNXモデルに変換しPythonで推論する

Last updated at Posted at 2024-11-25

ONNXでモデルを表現しよう

ONNX(Open Neural Network Exchange) は機械学習やディープラーニングなどのAIモデルを表現するためのオープンソースのフォーマットです。プラットフォームに依存しない、という特徴があり、ONNXモデルを経由することによりPyTorchで学習したモデルをTensorflowで推論することや、ワークステーション上のPythonスクリプトで学習したモデルをエッジデバイスのC++アプリで推論することなどが可能になります。

Raspberry Pi 5のAI KitであるHailo-8がONNXに対応しているとのことなので、ゴールはUbuntu24.04 + RTX3060Tiでモデルを学習させて、Raspberry Pi 5によりエッジで推論する構成を目指しているのですが、本記事では、まずMacBook Pro上でONNXを味見してみようということを目的に、遊んでみました。

関連記事

Netron

なお、本記事ではネットワークが変換できているかを確認するために、Netronを利用します。Netronはモデルを可視化できるツールで、インストールすることなくWebから利用可能です。Netronを使うことにより深層学習の各層の結合を確認できる他、各層の入力テンソルも確認することができます。下記の例は入力テンソルが (1, 3, 224, 224) となっているResNetの例です。

名称未設定 1.png

学習済みのResNetをONNXモデルとして扱う

PyTorchのモデルをONNXモデルへ書き出す

まずPyTorchのtorchvisionにバンドルされている学習済みのresnet101をONNXモデルに変換してみましょう。resnet101は入力に1枚の画像、RGBチャネル、224x224ピクセルの画像をとり、出力に1000種類のクラスの確率をとるモデルです。

resnet101をONNX形式へ変換するmodel_converter.pyの内容は以下となります。モデルをtorchvision.modelsから読み込みresnet変数に格納した後、推論モードへ切り替え、ダミーの入力データ(1, 3, 224, 224)を準備し、onnx.exportを呼び出すことでONNXモデルのファイルを生成できます。

# 実行のベースとなるtorchをインポート
import torch
# torchvisionからモデルと画像変換機能をインポート
from torchvision import models
# PyTorchのONNX拡張
import torch.onnx as onnx

##############
# resnetの準備
##############
print(dir(models))
# モデルに含まれているresnetを、学習済みの状態で取得
resnet = models.resnet101(pretrained=True)
# resnetのネットワーク構造を表示
print(resnet)
# resnetを推論モードに切り替え
resnet.eval()

# ダミーデータを作成する
input_image = torch.zeros((1, 3, 224, 224))
# ONNX形式で書き出す
onnx.export(resnet, input_image, "model.onnx")
# 終了する
quit()

PyTorchのモデルからONNXモデルへの変換手順は以下の通りです。変換にはonnxonnxruntimeをインストールする必要があるため、実行前に仮想環境にpipコマンドでインストールします。

# 作業ディレクトリを作成する
$ mkdir pytorchbox
$ cd pytorchbox 

# 仮想環境を作成する
$ conda create --prefix ./env python=3.8
# 仮想環境を有効化する
$ conda activate ./env
# PyTorchをインストールする
(env) $ conda install pytorch::pytorch torchvision torchaudio -c pytorch
# onnxとonnxランタイムをインストールする
(env) $ pip install onnx
(env) $ pip install onnxruntime

# PyTorchモデルをONNXへ変換するプログラム
(env) $ ls *.py
model_converter.py

# 変換を実行する
(env) $ python model_converter.py
...

# ONNXモデルができていることを確認する
(env) $ ls *.onnx
model.onnx

生成した model.onnx をNetronから開くことで、下記のようにネットワークを可視化可能です。入力テンソルが(1, 3, 224, 224)で、出力テンソルが(1, 1000)となっていることを確認することができますね。

スクリーンショット 2024-11-25 21.54.24.png

ONNXモデルをPythonで読み込んで推論する

続いて、ONNXモデルをPythonで読み込んで推論をしてみましょう。基本的な構造は下記記事と同等で、変更されている箇所は、推論の呼び出し方や後処理周りです。

利用したパッケージ

ONNXモデルの実行にPyTorchは必要ありませんが、今回はカメラからの入力画像やテンソルを加工しやすいことからtorchtorchvisionPILをインポートしています。また、ONNXモデルを扱うためonnxonnxruntimeをインポートしています。

全体の流れ

ONNXモデルはort.InferenceSessionにより読み込みます。モデルの入力層と出力層の名前を取得して、推論を実行する際(ort_session.run)に指定します。推論の入力となるWEBカメラで撮影した画像をPyTorchにより(1, 3, 224, 224)のテンソルに変換して.numpy().copy()でnumpy配列とします。これをort_session.runの引数に与えて推論を実行します。推論結果は1000クラスそれぞれの確率となりますので、確率を昇順に並べ、カメラに何が映ったかを判別します。今回は、出力層をNetronを見ると最後のGemm層で(x2048)していましたので、2048で割ってパーセンテージに変換し、50%以上であれば出力するようにしました。

# 実行のベースとなるtorchをインポート
import torch
# torchvisionからモデルと画像変換機能をインポート
from torchvision import transforms
# 画像を扱うライブラリPILをインポート
from PIL import Image

# ONNXを制御するパッケージをインポート
import onnx
import onnxruntime as ort
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions
# 数値処理用のパッケージをインポート
import numpy as np
# Webカメラを制御するOpenCVをインポート
import cv2

# /dev/video0を指定
DEV_ID = 0
# 撮影するサイズ
WIDTH = 640
HEIGHT = 480
# 認識の閾値
THRESHOLD = 50

##############
# resnetの準備
##############
# ONNX実行時のオプションを設定する
options = SessionOptions()
options.intra_op_num_threads = 1
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

# ONNXモデルを読み込む
ort_session = ort.InferenceSession('model.onnx')
ort_session.disable_fallback()
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name

# 入力層の名前を取得
print("input:")
for session_input in ort_session.get_inputs():
    print(session_input.name, session_input.shape)
# 出力層の名前を取得
print("output:")
for session_output in ort_session.get_outputs():
    print(session_output.name, session_output.shape)

# resnetの分類(クラス)をlabelsに読み込む
with open('./imagenet_classes.txt') as f:
    labels = [line.strip() for line in f.readlines()] 

# OpenCVで撮影した画像をresnetの入力テンソルに変換する
preprocess = transforms.Compose([
    transforms.ToPILImage(), # numpy.ndarray -> pillow
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

##############
# Webカメラに接続する
##############
cap = cv2.VideoCapture(DEV_ID)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, WIDTH)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, HEIGHT)
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)

##############
# OpenCVで撮影し、resnetテンソルに変換
##############
while(True):
    try:
        # 1フレーム撮影する
        ret, frame = cap.read()
        if ret:
            # 撮影したフレームを表示
            cv2.imshow('frame', frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
            # 画像をBGRからRGBの並びに変換
            input_image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # 画像をresnetの入力テンソルに変換する
            input_tensor = preprocess(input_image)
            # 画像をバッチに内包する(1次元追加)
            batch_t = torch.unsqueeze(input_tensor, dim=0)
            # ONNXに渡すデータへと変換する
            input_onnx = batch_t.numpy().copy()

            ##############
            # 入力テンソルをresnetで推論、結果を表示する
            ##############
            # カメラの画像をresnetで推論
            # runの第一引数:出力層の名前
            # runの第二引数:入力narray
            out_onnx = ort_session.run([output_name], {input_name: input_onnx})

            #out = torch.from_numpy(out_onnx[0][0]).clone()
            #_, index = torch.max(out, 0)
            out = out_onnx[0][0]
            # 確率を昇順にソートする
            sorted = np.sort(out)[::-1]
            sorted_idx = np.argsort(out)[::-1]
            # パーセンテージに変換
            parcentage = (sorted * 100 / 2048) * 100 
            # 先頭のアイテムを取り出す
            main_percentage = parcentage[0]
            main_label = labels[sorted_idx[0]]
            #print('%s(%s)' % (main_label, "{:.1f}".format(main_percentage, 1)))
            
            # 確率一定以上であれば結果を出力
            if main_percentage > THRESHOLD:
                print('%s(%s)' % (main_label, "{:.1f}".format(main_percentage, 1)))

    except KeyboardInterrupt:
        print('abort')

スクリプトの実行手順は以下の通りです。

# 作業ディレクトリへ移動
(base) $ cd pytorchbox 
(base) $ pwd
/Users/shino/pytorchbox

# 実行にはスクリプトの他に
# クラス名のリスト(imagenet_classes.txt)と
# ONNXモデルファイル(model.onnx:前手順でエクスポートしたもの)が必要となる
(base) $ ls
fetch_model_and_detect.py
model.onnx
env
imagenet_classes.txt
model_converter.py

# 仮想環境を有効にする
(base) $ conda activate ./env 
(env) $ 

# 実行
(env) $ python fetch_model_and_detect.py 
# book jacket(58.6)
# book jacket(56.5)
# book jacket(60.6)
# ... ターミナルを閉じて終了

スクリーンショット 2024-11-26 1.24.43.png

認識できましたね!これでまた一歩前進です!

参考させていただいた記事

今回も多数の記事を参考にさせていただきました。ありがとうございます。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?