7
4

More than 1 year has passed since last update.

StrongSORTを用いた物体追跡

Last updated at Posted at 2022-07-24

自分で作成した物体検出モデルを用いて,追跡したい方には必見の内容です.
この記事では,どんな物体検出モデルを用いても,その物体を追跡することが可能な方法を説明します.
今回は,以下のGitHubのリポジトリを使用し,説明していきます.

目次

  • OpenCVの顔検出モデルを用いた顔追跡
  • 他の物体検出モデルを用いた追跡方法

環境

  • Google Colaboratory

OpenCVの顔検出モデルを用いた顔追跡

今回使用するGitHubのリポジトリをクローンします.

$ git clone https://github.com/ysenkun/faces-detection-strongsort.git

OpenCVの顔検出モデルである「haarcascade_frontalface_default.xml」をダウンロードします.faces-detection-strongsortのディレクトリ内で以下のコマンドを実行してください.

$ wget -nc https://github.com/opencv/opencv/blob/master/data/haarcascades/haarcascade_frontalface_default.xml -O ./haarcascade_frontalface_default.xml

必要なライブラリをpipでインストールします.

$ pip3 install -r requirements.txt

それでは,実際に動かします.追跡したい顔動画のパスを入力し,実行してください.実行するとoutput.mp4で以下ような動画が出力されます.

$ python3 track.py --source vid.mp4 # video path

動作結果

face

他の物体検出モデルを用いた追跡方法

変更点は,以下の3つの関数の中身です.
1つ目が init関数を書き換えます,この関数では,物体検出モデルのパスやStrongSORTの設定など行なっています.今回は,モデルを読み込むためのパスを変更します.

track.py
def __init__(self, arg):
    #opencv model for face detection
-   self.face_cascade_path = 'haarcascade_frontalface_default.xml'
-   self.face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + self.face_cascade_path)

+   self.your_modle_path = #your_model_path    

2つ目は,any_model関数を書き換えます.この関数では,物体検出を行い,検出されたバウンディングボックスをStrongSORT用に変換します.はじめに,検出されたバウンディングボックスを変数(bbox)に格納します.次に検出されたバウンディングボックスの座標をStrongSORT用に[x_center, y_center, width, height]になるように変更します.また,confsとclssそれぞれにaccuracyとclassを格納します.
※クラス(class)とは,検出された物体の名称を表す.ここでは,数値に置き換えて変数に格納する.

track.py
def any_model(self,frame):
-   src_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
-   bbox = self.face_cascade.detectMultiScale(src_gray)

+   bbox = #your_object_detection
   
    outputs = []
    confs = []

    #Change annotation coordinates for StrongSORT
    #From [x_topleft, y_topleft, width, height] to [x_center, y_center, width, height]
    if bbox is not None and len(bbox):
        x = torch.tensor(bbox)
            
        xywhs = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
-       xywhs[:, 0] = (x[:, 0] + x[:, 2]/2) # x center
-       xywhs[:, 1] = (x[:, 1] + x[:, 3]/2) # y center

+       xywhs[:, 0] = x_center
+       xywhs[:, 1] = y_center
+       xywhs[:, 2] = width
+       xywhs[:, 3] = height
            
        #Static confs(accuracy) and clss(class) because opencv face detection is used
-       confs = torch.tensor([0.9 for i in range(len(bbox))])
-       clss = torch.tensor([0 for i in range(len(bbox))])

+       confs = torch.tensor([#accuracy_of_your_model])
+       clss = torch.tensor([#class_of_your_model])

        #Run StorngSORT
        outputs = self.strongsort.update(xywhs.cpu(), confs.cpu(), clss.cpu(), frame)
            
    return outputs,confs

最後にannotation関数を書き換えます.この関数は,検出された物体にアノテーション行なっています.ここでは,あなたのモデルのクラスに応じて,label名が変更するようにしてください.
※clssの数値に応じて,物体名を変更できるようにする

track.py
def annotation(self, frame, output, conf):
    bboxes = output[0:4]
    id = int(output[4])
    clss = int(output[5])
-   label = None #Make the object name change to match the clss number

+   label = #your_model_class  

次回

次回の記事では,実際に顔以外の物体を追跡したいと思います.

GitHub
MyHP

7
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
4