7
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Detectron2にバッチ単位で処理させる方法

Last updated at Posted at 2021-09-21

Detectron2を使ってバッチ処理を行う

Detectron2とは

Detectron2はfacebookが公開している物体検出ライブラリです。

  • 情報源
  • Detectron2の簡単な使い方はgoogle colabのチュートリアルで説明されています。Detectorn2でバッチを処理を行う前に一度このチュートリアルに目を通すことをお勧めします。このチュートリアルでは学習ループや評価コード内で呼び出す方法が説明されておらず公式リファレンスでしか説明されていません。そのためDataloaderから読み出した画像のテンソルのバッチを学習ループ内で利用する場合はチュートリアルの方法ではできません。
  • また、チュートリアルの方法では入力をuint8のnp.ndarry(channel, height, width)の形でないと動きません(実際はuint8でなくても動くそうですが、入力をfloatなどにして試したところ検出がうまくできていなかったのでuint8にすることをお勧めします)。

インストール方法

###必要環境

  • Linuxのpython3.6以上
  • Pytorch 1.7以上、Pytorchのバージョンに対応するtorchvision
  • OpenCV(検出結果を可視化する場合)

###インストール
このコマンドを実行する

$ python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'

詳しくは公式ドキュメントでインストールについて説明されています。またDetectron2はmacOSも対応しているのでmacでのインストールをする場合も同様にリンクを参考にしてください。

モデルの生成

ライブラリをimportします.

import torch, torchvision

import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

import numpy as np
import os, json, cv2, random

from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.modeling import build_model
from detectron2.checkpoint import DectionCheckpointer

detectron2.config .get_configを用いて生成したいモデルの設定を決めるには次のようにします.

  • cfg.merge_from_file(model_zoo.get_config_file(好きなモデルのyamlファイル))
  • cfg.MODEL.WEIGHTS = model_zoo.get_checkoutpoint_url(好きなモデルのyamlファイル)

ここで生成したいモデルを決めます。yamlファイルはgithubのconfigのディレクトリにあるファイル名のものを使用してくだいさい。またこの2つは同じファイル名にしてください。

検出の閾値の設定

  • cfg.MODEL.RETINANET.SCORE_THRESH_TEST = 好きな閾値
  • cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 好きな閾値
  • cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = 好きな閾値

モデルとタスクによって閾値を代入する変数名が変わるので使用するモデルに合わせて閾値の変数を決めてください

今回はCOCO_InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yamlを使った場合の例です

detectron2の設定
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")

cfgに設定したモデルを生成します。
今回は学習をしないのでeval()を使います。

モデルの作成
model = build_model(cfg)
model.eval()
checkpointer = DetectionCheckpointer(model)
checkpointer.load(cfg.MODEL.WEIGHTS)

モデルにはバッチのテンソルではなくリストを与える

今回はDataloaderによってバッチが呼び出されたものとします。

  • このバッチはtorch.Tensor(float32)の(batch, channel, height, width)の形になっていて、前処理として全て同じ大きさにリサイズされていて正規化されていないものとします。
  • カラー画像のchannelの順序はOpenCVと同じBGRとします。
  • 生成されたモデルにはDict["image":torch.Tensor, "height":int, "width":int]リストを入力とします。
    • "image"には画像のtorch.Tensor"height""width"には出力の大きさを指定します。

このTensor中の画像を一つずつ取り出し,リストにして,モデルに入力します.

with torch.no_grad():
    input_list = []
    for image in batch:  # batchはDataloaderから取り出したバッチ
        input = {"image": image, 
                 "height": height,
                 "width": width}  # heightとwidthは出力の大きさ
        input_list.append(input)
    outputs = model(input_list)

出力されたoutputs公式ドキュメントで説明されているフォーマットのリストが出力されます.

7
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?