まずはじめに 各技術の紹介
Segment Anything Model(SAM)とは
公式サイト
Github
概要
・Meta社(Facebook)が公開した、
・セグメンテーションモデル(塗りつぶし)
・ゼロショットでも高性能(トレーニングが必要ない)
Xmemとは
原文
・XMemは長いビデオのビデオオブジェクトセグメンテーションに使用される「アーキテクチャ」。
・このアーキテクチャの特徴記憶ストア←(これが重要)これがオブジェクトの位置や外観の変化を時間的に追跡し、記録する。
・これによってオブジェクトが時間とともにどのように変化するかを理解してその情報を用いて未来のフレーム(次のシーン)でのオブジェクトの位置を予測することが可能。
実は、この二つを組み合わせてビデオトラッキングを実現した派生技術が 下記
Track Anything
このコードで一発OKなのだが、
今回はWEBカメラを使って、リアルタイムトラッキングを行うため、
Segment anythingをベースにXmemを載せる感じでコードを作った。
こちらの記事を参考にしました
※※ありがとうございます
システムの流れ
①カメラを起動
②マウスクリックでセグメンテーションポイントの選択
‐SAMモデルにて推論を行いマスク(セグメンテーション)情報を取得
マスク情報をXmenに通しトラッキングオブジェクトを定義
③トラッキングの名前を定義
④②で選択したオブジェクトのトラッキングを開始
import datetime
import numpy as np
import torch
import cv2
以下2つのモジュールは別途Githubから落としてください
from segment_anything import sam_model_registry, SamPredictor
#https://github.com/facebookresearch/segment-anything/tree/main/segment_anything
from tracker.base_tracker import BaseTracker
#https://github.com/gaomingqi/Track-Anything/tree/master/tracker
#BaseTrackerの初期化
xmem_checkpoint = 'check_pt/XMem-s012.pth'
device = "cuda"
tracker = BaseTracker(xmem_checkpoint, device)
print("PyTorch version:", torch.__version__)
print("CUDA is available:", torch.cuda.is_available())
# Segment Anythingの初期化
sam_checkpoint = "check_pt/sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
# Videoの初期化
cap = cv2.VideoCapture(1)
if cap.isOpened() is False:
raise IOError
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
_, frame = cap.read()
h, w, _ = frame.shape
# 画面全体より小さい枠
margin = 50
input_box = np.array([margin, margin, w - margin, h - margin])
# 各種変数定義
input_point = None
input_label = np.array([1])
clicked_frame = None
target_name = ""
# マウスイベントの処理関数を定義
def mouse_event(event, x, y, flags, param):
global input_point, target_name, clicked_frame
if event == cv2.EVENT_LBUTTONDOWN:
input_point = np.array([[x, y]])
target_name = input("Enter a name for this point: ")
clicked_frame = frame.copy()
# マウスイベント時に処理を行うウィンドウの名前を設定
cv2.namedWindow('Frame')
cv2.setMouseCallback('Frame', mouse_event)
print("start.")
# 最初のフレームを表示し、任意の点をクリックするまで待つ
while input_point is None:
ret, frame = cap.read()
if ret is False:
raise IOError
cv2.imshow("Frame", frame)
cv2.waitKey(1)
first_frame = True
# セグメンテーション
predictor.set_image(cv2.cvtColor(clicked_frame, cv2.COLOR_BGR2RGB))
masks, _, _ = predictor.predict(
point_coords=input_point,
point_labels=input_label,
box=input_box,
multimask_output=False,
)
# トラッキング
mask, prob, painted_frame = tracker.track(clicked_frame, masks[0])
first_frame = False
while True:
try:
ret, frame = cap.read()
if ret is False:
raise IOError
# elseだけでいい
if first_frame:
mask, prob, painted_frame = tracker.track(clicked_frame, masks[0])
first_frame = False
else:
mask, prob, painted_frame = tracker.track(frame)
true_points = np.where(mask)
# 検出がなかったときように分岐
if true_points[0].size > 0 and true_points[1].size > 0:
top_left = true_points[1].min(), true_points[0].min()
bottom_right = true_points[1].max(), true_points[0].max()
color = (0, 0, 255) # red
thickness = 2
cv2.rectangle(frame, top_left, bottom_right, color, thickness)
# バウンディボックスやターゲット名をCVの画面上に描画![Something went wrong]()
text = target_name
org = (top_left[0], top_left[1] - 10)
font = cv2.FONT_HERSHEY_SIMPLEX
fontScale = 1
color = (255, 255, 255) # white
cv2.putText(frame, text, org, font, fontScale, color, thickness, cv2.LINE_AA)
cv2.imshow("Frame", frame)
cv2.imshow("Mask", mask * 255)
cv2.waitKey(1)![Something went wrong]()
except KeyboardInterrupt:
break