0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

PyTorch内蔵の学習済みResNetとMacBook Pro搭載カメラで物体検出の推論をする

Last updated at Posted at 2024-11-15

前回はM4 Pro MacBook Pro上にPyTorchの実行環境を構築しました。

今回は、PyTorchのtorchvisionに含まれている学習済みのresnet101を使ってWebカメラで撮影した物体を検出する推論を実装してみました。torchvisionは、人気の高いデータセット、モデルアーキテクチャ、コンピュータビジョンでよく利用される画像の加工機能を含むライブラリです。resnet101は物体検出を行うネットワークで、101層のニューラルネットワーク層から構成されています。

スクリーンショット 2024-11-16 7.49.34.png

実行環境

ハードウェアリソースとしてMacBook Proを使用しました。 WebカメラはMacBook Pro内蔵カメラを利用しています。 ソフトウェアの実装にはVisual Studio Codeを利用しました。Visual Studio Codeには、下記の拡張機能を入れています。

  • Japanese Language Pack
  • Python
  • Python Debugger
  • Pylance

Pythonの仮想環境を作成し、必要なコンポーネントをインストールする

作業ディレクトリを作成し、condaコマンドでPythonの仮想環境を作成します。作成したPython仮想環境に入り、PyTorchに加えて、Webカメラを制御するOpenCVと連携するためのopencv-pythonをインストールします。

(base) $ mkdir pytorchbox
(base) $ cd pytorchbox
(base) $ conda create --prefix ./env python=3.8
(base) $ conda activate ./env
(env) $ conda install pytorch::pytorch torchvision torchaudio -c pytorch
(env) $ pip install opencv-python

torchvisionに含まれているモデル

torchvisionに含まれているモデルは、以下の手順で確認できます。すべて学習済みのモデルなので、学習用のデータセットを集めることなく、すぐに物体検知のプログラムを組むことができます。勿論、学習させることもできますので、特に、ネットワークが既に決まっている方、はじめて推論プログラムを書く方にオススメです。

(env) $ python
>>> from torchvision import models
>>> print(dir(models))
['AlexNet', 'AlexNet_Weights', 'ConvNeXt', ...
'ResNet', 'ResNet101_Weights', 'ResNet152_Weights', ...
'resnet', 'resnet101', 'resnet152', ...
'wide_resnet101_2', 'wide_resnet50_2']

学習済みResNetの分類クラス一覧をダウンロードする

torchvisionに含まれているResNetはデータセットimagenetを使って学習されています。作業ディレクトリにimagenetのクラス一覧をダウンロードしておきましょう。

# 特にenv以下で作業する必要はないですが...
(env) $ wget https://raw.githubusercontent.com/pytorch/hub/refs/heads/master/imagenet_classes.txt
(env) $ ls
env   imagenet_classes.txt

Visual Studio Codeで作業ディレクトリを開く

Visual Studio Codeで作業ディレクトリを開き、ターミナルでPythonの仮想環境を有効化します。また、 Pythonファイルを作成します。今回はtesttorch.pyという名前のファイルを作成しました。

# Visual Studio Codeのターミナルも仮想環境を有効化する
(base) $ conda activate ./env
(env) $

スクリーンショット 2024-11-16 5.51.38.png

OpenCVでWebカメラの画像を採取し、推論する

推論モデルと撮影処理を別々に準備し、最後に結合します。

推論側の実装

まず推論側では、推論に使うresnet101torchvisionから読み込み、MacBookのアクセラレータ(MPS:Metal Perfomance Shaders)に転送します。また、入力画像をresnetの入力テンソルに変換するためのpreprocessも定義します。この定義にはtorchvisionの提供するtransformsを利用しました。transforms.Composeを利用することにより処理をまとめることが可能で、前処理を並べておくと、入力画像に対して順番に適用するフィルタとなってくれます。本例では、5段階の前処理を適用しています。

  • OpenCVのRGB形式を画像を扱うPIL対応の形式に変換
  • 画像のリサイズ(256 x 256px に縮小)
  • 画像のクロップ(中央の224 x 224pxを抽出)
  • テンソルに変換
  • RGBそれぞれの値をNormalize

Normalizeに指定したパラメータ

ここでのポイントは、画像分類データセットを構成している画像に含まれるRGBの平均と標準偏差mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225] なので、この値を使ってNormalizeしている点です。ですので、異なる学習データセットに変更してresnetを利用する場合は、そのデータセットの平均値と標準偏差にあわせる必要があります。

撮影側の実装

OpenCV(撮影)側では、MacBook Pro内蔵カメラにcv2.VideoCapture()メソッドでアクセスし、撮影サイズとバッファ深度を指定します。OpenCVで撮影した画像はBGR形式となっていますので、PILライブラリの扱いやすいRGB形式へと変換し、推論モデルの入力へと変換するため前手順で実装したpreprocessに投入します。shapeを確認し、入力バッチとなるtorch.Size([1, 3, 224, 224])になっていればOKです。最後に.to("mps")メソッドを利用して入力バッチをアクセラレータへ転送します。

撮影結果をresnetで分類する

最後に、撮影側の実装で作成したbatch_tresnetへ入力し、出力テンソルoutを取得します。出力テンソルには分類クラスそれぞれの確率が含まれています。maxメソッドを使って並び替えることにより先頭の要素のインデックスを得ることができますのでlabelsを参照することで、カメラが何を撮影したのか認識することができます。先頭の確率が低い場合、特に何も撮影できていない可能性がありますので、今回はif main_percentage > 50:として、50%以上の場合のみラベルを出力するものとしました。

ソースコード全体

以上の内容を実装した結果が以下となります。

# 実行のベースとなるtorchをインポート
import torch
# torchvisionからモデルと画像変換機能をインポート
from torchvision import models
from torchvision import transforms
# 画像を扱うライブラリPILをインポート
from PIL import Image
# Webカメラを制御するOpenCVをインポート
import cv2

# /dev/video0を指定
DEV_ID = 0
# 撮影するサイズ
WIDTH = 640
HEIGHT = 480

##############
# resnetの準備
##############
# モデルに含まれているresnetを、学習済みの状態で取得
resnet = models.resnet101(pretrained=True)
# resnetのネットワーク構造を表示
print(resnet)
# resnetを推論モードに切り替え
resnet.eval()
# resnetのモデルをMPS(アクセラレータ)へ転送
resnet.to("mps")
# resnetの分類(クラス)をlabelsに読み込む
with open('./imagenet_classes.txt') as f:
    labels = [line.strip() for line in f.readlines()] 

# OpenCVで撮影した画像をresnetの入力テンソルに変換する
preprocess = transforms.Compose([
    transforms.ToPILImage(), # numpy.ndarray -> pillow
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

##############
# Webカメラに接続する
##############
cap = cv2.VideoCapture(DEV_ID)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, WIDTH)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, HEIGHT)
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)

##############
# OpenCVで撮影し、resnetテンソルに変換
##############
while(True):
    try:
        # 1フレーム撮影する
        ret, frame = cap.read()
        if ret:
            # 撮影したフレームを表示
            cv2.imshow('frame', frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
            # 画像をBGRからRGBの並びに変換
            input_image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # 画像をresnetの入力テンソルに変換する
            input_tensor = preprocess(input_image)
            # 画像をバッチに内包する(1次元追加)
            batch_t = torch.unsqueeze(input_tensor, dim=0)
            #print(input_tensor.shape) # -->> torch.Size([3, 224, 224])
            #print(batch_t.shape) # -->> torch.Size([1, 3, 224, 224])
            # 入力テンソルをアクセラレータに転送
            batch_t = batch_t.to("mps")

            ##############
            # 入力テンソルをresnetで推論、結果を表示する
            ##############
            # 入力テンソルに変換したカメラの画像をresnetで推論
            out = resnet(batch_t)
            # 最大確率のインデックスを取得
            _, index = torch.max(out, 1)
            # 確率をパーセンテージ表記に変更
            percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
            main_percentage = percentage[index[0]].item()
            main_label = labels[index[0]]
            # 確率50%以上であれば結果を出力
            if main_percentage > 50:
                print('%s(%s)' % (main_label, "{:.1f}".format(main_percentage, 1)))

    except KeyboardInterrupt:
        print('abort')

もちろん、別の学習済みモデルも使えますのでresnetを初期化している箇所を以下のように書き換えるだけでresnet101の推論器からresnext50_32x4dの推論器に切り替え、推論時間を短縮することが可能です。しかも.to("mps")でアクセラレータも使えてしまうとは...PyTorchの学習済みモデル、本当に便利だなぁ。

# モデルに含まれているresnetを、学習済みの状態で取得
#resnet = models.resnet101(pretrained=True)
resnet = models.resnext50_32x4d(pretrained=True)

以上、ありがとうございました。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?