4
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

OpenCVAdvent Calendar 2024

Day 21

Pythonでパーティクルフィルタを実装してオブジェクト追跡をする方法

Posted at

はじめに

パーティクルフィルタは、動的な環境におけるオブジェクト追跡や位置推定のための強力なアルゴリズムです。本記事では、PythonとOpenCVを使用してパーティクルフィルタを実装し、動画中のオブジェクト追跡を行う方法を紹介します。


パーティクルフィルタの概要

パーティクルフィルタは、状態推定の問題をランダムサンプリングで解決するベイズフィルタの一種です。多数のパーティクル(仮説点)を用いて、システムの状態(位置や速度など)を推定します。以下の手順で進みます:

  1. 初期化: パーティクルをランダムに生成。
  2. 予測: パーティクルをシステムモデルに基づいて予測。
  3. 尤度計算: 各パーティクルが観測データに一致するかを評価。
  4. リサンプリング: 尤度に基づいてパーティクルを再配置。

実装の概要

今回のコードでは、以下の要素を実装しています:

  • Point2d: 2次元ポイントを表現するクラス。
  • Particle: パーティクルを表現するクラス。
  • ParticleFilter: パーティクルフィルタのメインクラス。
  • 動画処理: 動画を読み込み、フレームごとにパーティクルフィルタを適用。

実装のポイント

1. パーティクルの定義

パーティクルは、位置 (pos)、速度 (vel)、尤度 (like)、重み (weight) などを持つデータクラスとして定義します。

@dataclass
class Particle():
    pos: Point2d
    vel: Point2d
    like: float
    weight: float
    keep: bool

Point2d クラスを使って、位置と速度の計算を簡潔に行います。

2. 初期化と予測

パーティクルフィルタは最初にランダムな位置にパーティクルを生成します。その後、速度に基づいてパーティクルの位置を予測します。

def initialize(self):
    for i in range(self.n_particles):
        pt = Point2d(
            self._rs.randint(0, self._img_cols),
            self._rs.randint(0, self._img_rows)
        )
        self.particle_list.append(
            Particle(pt, Point2d(0, 0), 1.0, 0.0, False)
        )

def predict(self):
    for i_particle in self.particle_list:
        i_particle.pos += i_particle.vel * self.dt

3. 尤度の更新

各パーティクルの位置の色相値(Hue)と彩度(Saturation)を計算し、尤度を更新します。色相値70、彩度200を目標としています。

def _calc_like(self, img, i_particle):
    img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    vec_hsv = cv2.split(img_hsv)

    h = vec_hsv[0][i_particle.pos.y, i_particle.pos.x]
    s = vec_hsv[1][i_particle.pos.y, i_particle.pos.x]
    len_h = abs(70 - h)
    len_s = abs(200 - s)
    like = (len_h / 180) * 0.8 + (len_s / 255) * 0.2
    return 1 - like

4. リサンプリング

尤度の高いパーティクルを保持し、新しいパーティクルを生成します。

def resample(self):
    self.particle_list.sort()

    particle_list = [i_particle for i_particle in self.particle_list if i_particle.keep]
    like_sum = sum([i_particle.like for i_particle in particle_list])

    for i_particle in particle_list:
        i_particle.weight = i_particle.like / like_sum

    particle_list_new = []
    for i_particle in particle_list:
        n_particles_new = int(i_particle.weight * (self.n_particles - len(particle_list)))
        for _ in range(n_particles_new):
            r = self._rs.normal(loc=0.0, scale=self._img_rows + self._img_cols) * (1 - i_particle.like)
            ang = self._rs.uniform(-np.pi, np.pi)
            pt = Point2d(
                r * np.cos(ang) + i_particle.pos.x,
                r * np.sin(ang) + i_particle.pos.y
            )
            particle_list_new.append(
                Particle(pt, pt - i_particle.pos, i_particle.like, i_particle.weight, False)
            )

    self.particle_list = particle_list + particle_list_new

5. 描画

パーティクルを描画し、重心を計算して中心線を引きます。

def plot_particles(self, img):
    _img = img.copy()
    for i_particle in self.particle_list:
        if 0 < i_particle.pos.x < self._img_cols and 0 < i_particle.pos.y < self._img_rows:
            _img = cv2.circle(
                img=_img,
                center=i_particle.pos.xy,
                radius=2,
                color=(0, 0, 255),
                thickness=-1)
    return _img

def plot_center(self, img):
    _img = img.copy()
    center = sum([i_particle.pos for i_particle in self.particle_list]) / len(self.particle_list)

    cv2.line(_img,
             (center.x, 0),
             (center.x, self._img_rows),
             (0, 255, 255),
             3)
    cv2.line(_img,
             (0, center.y),
             (self._img_cols, center.y),
             (0, 255, 255),
             3)
    return _img

実装まとめ

main.py
import time
from dataclasses import dataclass

import cv2
import numpy as np


class Point2d():
    def __init__(self, x, y):
        self.x = int(x)
        self.y = int(y)

    def __add__(self, other):
        return self.__class__(self.x + other.x, self.y + other.y)

    def __sub__(self, other):
        return self.__class__(self.x - other.x, self.y - other.y)

    def __mul__(self, other):
        return self.__class__(round(self.x * other), round(self.y * other))

    def __truediv__(self, other):
        return self.__mul__(1 / other)

    def __radd__(self, other):
        return self.__class__(self.x + other, self.y + other)

    @property
    def xy(self):
        pass

    @xy.getter
    def xy(self):
        return (self.x, self.y)


@dataclass
class Particle():
    pos: Point2d
    vel: Point2d
    like: float
    weight: float
    keep: bool

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.like == other.like
        else:
            kls = other.__class__.__name__
            raise NotImplementedError(
                f'comparison between {self.__class__.__name__} and {kls} is not supported')

    def __ne__(self, other):
        return not self.__eq__(other)

    def __lt__(self, other):
        if isinstance(other, self.__class__):
            return self.like < other.like
        else:
            kls = other.__class__.__name__
            raise NotImplementedError(
                f'comparison between {self.__class__.__name__} and {kls} is not supported')

    def __le__(self, other):
        return self.__lt__(other) or self.__eq__(other)

    def __gt__(self, other):
        return not self.__le__(other)

    def __ge__(self, other):
        return not self.__lt__(other)


class ParticleFilter():
    def __init__(self, shape, n_particles=1000, dt=1, random_state=0, like_thresh=0.9, thresh_keep=100):
        self.n_particles = n_particles
        self.dt = dt
        self.random_state = random_state
        self.like_thresh = like_thresh
        self.thresh_keep = thresh_keep

        self._img_cols, self._img_rows = shape
        self._rs = np.random.RandomState(random_state)
        self.initialize()

    def initialize(self):
        self.particle_list = []
        for i in range(self.n_particles):
            pt = Point2d(
                self._rs.randint(0, self._img_cols),
                self._rs.randint(0, self._img_rows)
            )
            self.particle_list.append(
                Particle(pt, Point2d(0, 0), 1.0, 0.0, False)
            )

    def predict(self):
        for i_particle in self.particle_list:
            i_particle.pos += i_particle.vel * self.dt

    def update_like(self, img):
        # 色相値、彩度値から尤度を計算、更新
        for i_particle in self.particle_list:
            if 0 < i_particle.pos.x < self._img_cols and 0 < i_particle.pos.y < self._img_rows:
                i_particle.like = self._calc_like(img, i_particle)
            else:
                i_particle.like = 0

    def _calc_like(self, img, i_particle):
        """尤度計算メソッド.自分でカスタマイズする.
        """
        img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        vec_hsv = cv2.split(img_hsv)

        # 尤度を計算
        h = vec_hsv[0][i_particle.pos.y, i_particle.pos.x]
        s = vec_hsv[1][i_particle.pos.y, i_particle.pos.x]
        len_h = abs(70 - h)
        len_s = abs(200 - s)
        like = (len_h / 180) * 0.8 + (len_s / 255) * 0.2
        return 1 - like

    def resample(self):
        # 尤度の昇順にソート
        self.particle_list.sort()

        # 尤度の高いパーティクルを残し、尤度の低いパーティクルを消す
        for i, i_particle in enumerate(self.particle_list):
            if i_particle.like > self.like_thresh or i > len(self.particle_list) - self.thresh_keep:
                i_particle.keep = True
            else:
                i_particle.keep = False
        particle_list = [i_particle for i_particle in self.particle_list if i_particle.keep]

        # 尤度の高いパーティクルの計数、尤度の合計
        like_sum = sum([i_particle.like for i_particle in particle_list])

        # 正規化した重みを計算
        for i_particle in particle_list:
            i_particle.weight = i_particle.like / like_sum

        # リサンプリング
        particle_list_new = []
        for i_particle in particle_list:
            n_particles_new = int(i_particle.weight * (self.n_particles - len(particle_list)))
            for _ in range(n_particles_new):
                r = self._rs.normal(loc=0.0, scale=self._img_rows + self._img_cols) * (1 - i_particle.like)
                ang = self._rs.uniform(-np.pi, np.pi)
                pt = Point2d(
                    r * np.cos(ang) + i_particle.pos.x,
                    r * np.sin(ang) + i_particle.pos.y
                )
                # velocityは位置の差分で計算
                particle_list_new.append(
                    Particle(pt, pt - i_particle.pos, i_particle.like, i_particle.weight, False)
                )

        # パーティクル更新
        self.particle_list = particle_list + particle_list_new

    def plot_particles(self, img):
        _img = img.copy()
        # パーティクル描画
        for i_particle in self.particle_list:
            if 0 < i_particle.pos.x < self._img_cols and 0 < i_particle.pos.y < self._img_rows:
                _img = cv2.circle(
                    img=_img,
                    center=i_particle.pos.xy,
                    radius=2,
                    color=(0, 0, 255),
                    thickness=-1)
        return _img

    def plot_center(self, img):
        _img = img.copy()
        # パーティクルの重心
        center = sum([i_particle.pos for i_particle in self.particle_list]) / len(self.particle_list)

        cv2.line(_img,
                 (center.x, 0),
                 (center.x, self._img_rows),
                 (0, 255, 255),
                 3)
        cv2.line(_img,
                 (0, center.y),
                 (self._img_cols, center.y),
                 (0, 255, 255),
                 3)
        return _img


def main():
    WIN_PF = 'main'
    RESIZE = (640, 360)  # (x, y)

    pf = ParticleFilter(
        shape=RESIZE,
        n_particles=1000,
        dt=1,
        random_state=0,
        like_thresh=0.9,
        thresh_keep=100)

    cap = cv2.VideoCapture('04-06.wmv')
    cv2.namedWindow(WIN_PF, cv2.WINDOW_NORMAL)

    while True:
        # 予測
        pf.predict()

        ret, img_src = cap.read()

        if not ret:
            break

        img_src = cv2.resize(img_src, RESIZE)
        pf.update_like(img_src)
        pf.resample()

        img_src = pf.plot_particles(img_src)
        img_src = pf.plot_center(img_src)

        cv2.imshow(WIN_PF, img_src)

        if cv2.waitKey(1) == ord('q'):
            break

        # 描画のための遅延
        time.sleep(0.2)

    cap.release()
    cv2.destroyAllWindows()


if __name__ == '__main__':
    main()

実行方法

1.必要なライブラリをインストールします。

pip install opencv-python numpy

2.04-06.wmv という動画ファイルを用意し、コードの同じディレクトリに配置します。動画はこちらのサイトのサンプルプログラムに同梱されています。

3.コードを実行します。

python main.py

4.動画の再生中にオブジェクト追跡結果を確認できます。


結果

実行すると、パーティクルが画面上でオブジェクトを追跡する様子が表示されます。各パーティクルの位置が赤い点で描画され、重心が黄色い線で示されます。

processed_04-06_fps30_1sec.gif


応用

この実装は、特定の色相値と彩度を追跡する例ですが、カスタマイズすることで以下のような応用が可能です:

  • 顔検出や物体検出: Haarカスケードや深層学習モデルと組み合わせる。
  • ロボットの位置推定: LiDARやカメラのデータと統合する。
  • その他の追跡対象: 特徴量やテンプレートマッチングを使用。

まとめ

本記事では、Pythonでパーティクルフィルタを実装して動画中のオブジェクトを追跡する方法を紹介しました。この実装を基に、さまざまな応用やカスタマイズに挑戦してみてください!

4
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?