LoginSignup
5
5

More than 1 year has passed since last update.

Detectron2を使った自作データの学習 ~キーポイント検出~

Last updated at Posted at 2023-01-01

1.はじめに

Detectron2とは、Facebook AIが開発した、PyTorchベースの物体検出のライブラリです。今回はDetectron2を用いた自作データの学習と題して、犬のキーポイント検出を行っていこうと思います。

作業環境としてはGoogle Colabを利用します。

2.キーポイント検出とは

キーポイント検出とは、画像に写っている人などの対象物の特徴点を検出することです。
よく使われるのが、人体の手足等のポイントを検出することで、その人の姿勢推定を行う例があります。

  • 参考画像

image.png

人のキーポイント検出を行う学習モデルは多数ありますが、新たな物体のキーポイント検出をする際は、自作データで学習を行う必要があります。

今回は"犬"のキーポイント検出を行っていこうと思います。

3.方法

3-1. 学習データ作成

アノテーションにはCOCO Annotatorを用います。
詳細はこちらのページを参考にしました。

犬の頭、首、右前足、左前足、お尻、右後足、左後足の7点をキーポイントとして設定すると、こんな感じになります。
image.png

3-2. 学習

まずdetectron2をインストールします。
公式チュートリアルを参考にしましょう。


!python -m pip install pyyaml==5.1
import sys, os, distutils.core
!git clone 'https://github.com/facebookresearch/detectron2'
dist = distutils.core.run_setup("./detectron2/setup.py")
!python -m pip install {' '.join([f"'{x}'" for x in dist.install_requires])}
sys.path.insert(0, os.path.abspath('./detectron2'))

下記コードが実行ができれば、正常にインストールされています。

import torch, detectron2
!nvcc --version
TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
CUDA_VERSION = torch.__version__.split("+")[-1]
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)
print("detectron2:", detectron2.__version__)

結果

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Sun_Feb_14_21:12:58_PST_2021
Cuda compilation tools, release 11.2, V11.2.152
Build cuda_11.2.r11.2/compiler.29618528_0
torch:  1.12 ; cuda:  cu113
detectron2: 0.6

importするライブラリは下記の通りです。

# Some basic setup:
# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

# import some common libraries
import numpy as np
import os, json, cv2, random
from google.colab.patches import cv2_imshow

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog

データセットを登録します。
下記コードでdog-projectという名前でデータが登録され、今後データを読み出す際はこのdog-projectという名前を利用します。

from detectron2.data.datasets import register_coco_instances
# register_coco_instances(name, metadata, json_file, image_root) # parameter
register_coco_instances("dog-project", {}, "/content/drive/MyDrive/data/dog/test.json", "/content/drive/MyDrive/data/dog")

データセットの詳細を設定します。
Keypoint検出に特有な項目が下記3点です。
①keypoint_names
設定したkeypointの名称を順番にリスト形式で渡します。
list[str]

②keypoint_flip_map
画像を水平方向反転した際に、ラベルを入れ替える項目を設定します。
データオーギュメンテーションの際に用いられます。
ex.) 人:右手と左手、右足と左足
list[tuple[str]]

③keypoint_connection_rules
各キーポイントの対応と表示する際の色を(R,G,B)で渡します。
list[tuple(str, str, (r, g, b))]

*参照

詳細設定の際は、
MetadataCatalog.get("dog-project(*register_coco_instancesで設定した名前)")
を使用しましょう。

from detectron2.data import MetadataCatalog

keypoint_names = ['head', 'neck', 'front-right', 'front-left', 'hip', 'hind-right', 'hind-left']
keypoint_flip_map = [('front-right','front-left'),('hind-right','hind-left')]
keypoint_connection_rules = [('head','neck',(128,0,0)),('neck','front-right',(0,128,0)),
                             ('neck','front-left',(0,0,128)),('neck','hip',(255,0,0)),('hip','hind-right',(0,255,0)),
                             ('hip','hind-left',(0,0,255))]

MetadataCatalog.get("dog-project").thing_classes = ["Dog"]
MetadataCatalog.get("dog-project").thing_dataset_id_to_contiguous_id = {1:0}
MetadataCatalog.get("dog-project").keypoint_names = keypoint_names
MetadataCatalog.get("dog-project").keypoint_flip_map = keypoint_flip_map
MetadataCatalog.get("dog-project").keypoint_connection_rules = keypoint_connection_rules

読み込んだデータを確認します。

dog_metadata = MetadataCatalog.get("dog-project")
dataset_dicts = DatasetCatalog.get("dog-project")

for d in random.sample(dataset_dicts, 3):
  img = cv2.imread(d["file_name"])
  visualizer = Visualizer(img[:, :, ::-1], metadata=dog_metadata, scale=1.0)
  vis = visualizer.draw_dataset_dict(d)
  cv2_imshow(vis.get_image()[:, :, ::-1])

こんな形で画像が表示されればOKです。
image.png

それでは学習を行っていきます。

注意点は下記2点です。
①全体のクラス数設定
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
cfg.MODEL.RETINANET.NUM_CLASSES = 1

ここの部分は自分が設定したインスタンスのクラス数(ここでは犬のみなので1)を設定しましょう。

②keypointの数
cfg.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS = 7
cfg.TEST.KEYPOINT_OKS_SIGMAS = np.ones((7,1),dtype=float).tolist()

ここの部分は自分が設定したkeypointの数(ここでは"head", "neck", "fornt-right", "front-left", "hip", "hind-right", "hind-left"の7)を設定しましょう。

参照

OKSについての説明はこちらを参照

from detectron2.engine import DefaultTrainer

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-Keypoints/keypoint_rcnn_R_50_FPN_1x.yaml"))
cfg.DATASETS.TRAIN = ("dog-project",)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Keypoints/keypoint_rcnn_R_50_FPN_1x.yaml")  # Let training initialize from model zoo
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025  # pick a good LR
cfg.SOLVER.MAX_ITER = 2000    # 300 iterations seems good enough for this toy dataset; you will need to train longer for a practical dataset
cfg.SOLVER.STEPS = []        # do not decay learning rate
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128   # faster, and good enough for this toy dataset (default: 512)
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
cfg.MODEL.RETINANET.NUM_CLASSES = 1
cfg.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS = 7
cfg.TEST.KEYPOINT_OKS_SIGMAS = np.ones((7,1),dtype=float).tolist()
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg) 
trainer.resume_or_load(resume=False)
trainer.train()

学習が実行されます。

3-3. 推論

学習されたモデルは、"./output/model_final.pth"に保存されます。
推論用のモデルはそこを参照しましょう。

cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "/content/drive/MyDrive/output/model_final.pth")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
cfg.MODEL.DEVICE = "cpu"
predictor = DefaultPredictor(cfg)

学習データにない画像でテストしてみます。

imgPath = "/content/drive/MyDrive/data/dog/test.png"
im = cv2.imread(imgPath)
    
outputs = predictor(im)
v = Visualizer(im[:, :, ::-1],
               metadata=dog_metadata, 
               scale=1.0
)
v = v.draw_instance_predictions(outputs["instances"].to("cpu"))
cv2_imshow(v.get_image()[:, :, ::-1])

実行結果がこちら。

image.png

右前足(濃い緑の線)がおかしなところにいっていますが、おおよそ良いのではないでしょうか。今回は学習データ15枚なので、もっと増やせば精度あがると思います。

4. まとめ

今回はDetectron2を使った自作データでのキーポイント検出を試してみました。非常に手軽にできるため、色々なことに応用できそうですね。

5. 参考

参考YOU TUBE動画
https://www.youtube.com/watch?v=xX9fPmclEHY

5
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
5