0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

PyTorchによる推論デモとその困難ポイント

Last updated at Posted at 2025-03-25

はじめに

この文書ではYOLOv7で動画を推論するデモを作成しながら、アプリ作成における困難ポイントを記述していく。

作成するアプリは以下の通り動画をYOLOv7で推論しバウンディングボックスを描画してウインドウ表示をおこなう。

image.png

アプリ作成

git リポジトリ作成

今回のコードはgitを使用して管理していくが、git操作については一部を除いて割愛する。

YOLOv7リポジトリを取り込む

$ git submodule add https://github.com/WongKinYiu/yolov7.git

動画再生

まずOpenCVを使って動画再生だけ行うアプリケーションを作成する。この雛形は動画推論デモの基本となるため覚えておくと便利である。

demo.py
import cv2
import argparse

def main(input_path):
    cap = cv2.VideoCapture(input_path)

    while True:
        ret, frame = cap.read();
        if not ret:
            break

        cv2.imshow("output", frame)
        if ord('q') == cv2.waitKey(10):
            break

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input',
                        default='/dev/video0',
                        help = 'input file or video device (default: /dev/video0)')

    args = parser.parse_args()
    main(args.input)

以下のように実行できる。

$ python demo.py -i <video file>

重みファイルのダウンロード

YOLO系に限らずPyTorchの推論モデルは訓練済みのモデルをダウンロードしてきて読み込ませる必要がある。

今回使用するyolov7.ptをダウンロードしておく。

$ curl -OL https://github.com/WongKinYiu/yolov7/releases/download/v0.1/yolov7.pt

困難ポイント
gitリポジトリは存在しているが重みファイルのダウンロードリンクが切れておりモデルの使用を諦めざるを得ないことがある。

モデルと重みの読み込み

demo.py(抜粋)
import sys
sys.path.append('./yolov7')
 
from models.experimental import attempt_load

(中略)

    torch_model = attempt_load(weights='./yolov7.pt', map_location='cpu').autoshape()
    torch_model.eval() # run for test

困難ポイント
Q. attempt_load()なんて便利な関数どうやって知ったのか。

A. 気合でコードを読みましょう。

推論コード追加

cap.read()で読み込んだ動画のフレーム(frame)を以下のように推論させる。

demo.py(抜粋)
        frame = cv2.resize(frame, (640, 480))
        im = torch.tensor(frame.astype(np.float32))

        im = im.permute(2,0,1) # [H, W, C] -> [C, H, W]
        im = im.unsqueeze(0) # [C, H, W] -> [1, C, H, W]
        im /= 255 # [0,255] -> [0.0,1.0]

        with torch.no_grad():
            res = torch_model(im)

RGBの画像ファイルは[H,W,C](CはチャンネルRGBなら3)となっているが、推論を行うときは概ね[1,C,H,W]にすることが多く、permute(), unsqueeze()はそういった変換をおこなうメソッドなので覚えておくと良い。

またほとんどの画像処理モデルは入力値は0.0-1.0の範囲で受け取ることが多いため255で割っている。

困難ポイント
SSD512のように入力サイズが512x512と決まっているモデルもあるがYOLO系は入力サイズは可変となっている。しかしサイズによっては推論途中でエラーになることもあり結局どういうサイズにすればよいかよくわからない。(誰か教えて)
あと入力サイズが大きいとメモリ不足になることもある。

バウンディングボックス算出と描画

demo.py(抜粋)
from utils.general import non_max_suppression
from utils.plots import plot_one_box

(中略)

        out = non_max_suppression(res[0], conf_thres=0.25, iou_thres=0.45)
        det = out[0]

        for *xyxy, conf, cls in reversed(det):
            plot_one_box(xyxy, frame, label=coco_names[int(cls)], color=(0,0,255), line_thickness=2)

non_max_suppression()はバウンディングボックスを算出する関数で物体検出のリポジトリでは必ずと言ってよいほど持っている。(NMSと略されることもある。)

またplot_one_boxという関数でframeにバウンディングボックスを書き込んでいる。

困難ポイント
Q. plot_one_box()なんて便利な関数どうやって知ったのか。

A. 気合でコードを読みましょう。

困難ポイント
実は推論モデルで推論するところまではそれほど難しくはない。難しいのは推論の結果出てきたテンソルをどう解釈すればよいかというところである。

推論モデルに詳細なドキュメントがあることは非常に稀で、既存コードを頑張って読んだり、時には元となった論文まで読み込む必要がある。

コード全体

demo.py
import sys
sys.path.append('./yolov7')

import numpy as np
import torch
import cv2
import argparse
import yaml

from utils.general import non_max_suppression
from utils.plots import plot_one_box

from models.experimental import attempt_load

def main(input_path):
    with open('yolov7/data/coco.yaml') as yml:
        coco = yaml.safe_load(yml)
    coco_names = coco['names']

    cap = cv2.VideoCapture(input_path)

    torch_model = attempt_load(weights='./yolov7.pt', map_location='cpu').autoshape()
    torch_model.eval() # run for test

    while True:
        ret, frame = cap.read();
        if not ret:
            break

        frame = cv2.resize(frame, (640, 480))
        im = torch.tensor(frame.astype(np.float32))

        im = im.permute(2,0,1) # [H, W, C] -> [C, H, W]
        im = im.unsqueeze(0) # [C, H, W] -> [1, C, H, W]
        im /= 255 # [0,255] -> [0.0,1.0]

        with torch.no_grad():
            res = torch_model(im)

        out = non_max_suppression(res[0], conf_thres=0.25, iou_thres=0.45)
        det = out[0]

        for *xyxy, conf, cls in reversed(det):
            plot_one_box(xyxy, frame, label=coco_names[int(cls)], color=(0,0,255), line_thickness=2)

        cv2.imshow("output", frame)
        if ord('q') == cv2.waitKey(10):
            break

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input',
                        default='/dev/video0',
                        help = 'input file or video device (default: /dev/video0)')

    args = parser.parse_args()
    main(args.input)
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?