LoginSignup
8
6

More than 1 year has passed since last update.

mmdetection 完全理解マニュアル (推論編)

Last updated at Posted at 2022-08-18

この記事の目的

この記事の目的は mmdetection の推論ができるようになることです。推論ができるとは具体的に以下のような状態を指します。

  1. 学習済み重みと config をロードしたモデルインスタンスを作成する
  2. モデルの推論を行う
  3. 推論結果を可視化する

なお、mmdetection のバージョンは 2.25.1 を想定しています。

mmdetection を使って推論をする

1. 学習済み重みと config をロードしたモデルインスタンスを作成する

モデルの初期化は mmdet.apis.init_detector によって行えます。
mmdet.apis.init_detector は3つの引数1を持っています。

  • config: モデル・パイプラインの定義ファイルです。config ファイルのパス / mmcv.Config インスタンスを渡すことができます2
  • checkpoint: モデルの重みです。重みファイルのパス / URL / s3 アドレス など3として渡すことができます
  • device: 推論時のデバイスを指定します。デフォルトだと 'cuda:0' が使われます

init_detector の戻り値は引数で指定したモデルのインスタンス (nn.Module) です。

from mmdet.apis import init_detector


# set appropriate paths
config_file = 'path/to/config'
checkpoint_file = 'path/to/checkpoint'

# initialize model
model = init_detector(config_file, checkpoint_file, device='cuda:0')

2.モデルの推論を行う

mmdetection のモデル推論は mmdet.apis.inference_detector によって行えます。inference_detector の引数は2つあります。

  • model: モデルインスタンス。init_detector で初期化したもの
  • imgs: 入力画像。以下のいずれか
    • 画像のパス (str or pathlib.Path)
    • 画像の配列 (numpy.ndarray)
      • numpy.uint8 か  numpy.float32 である必要があります
    • 画像のパスないし配列のリスト (list)
    • 画像のパスないし配列のタプル (tuple)

inference_detector の戻り値はモデルの種類と入力の種類によって分岐します。

モデルの種類 input: img input: Iterable[img]
object detection List[numpy.ndarray] List[List[numpy.ndarray]]
instance segmentation List[List[numpy.ndarray]] List[List[List[numpy.ndarray]]]
panoptic segmentation dict List[dict]

いずれのケースでも入力をリスト or タプルで渡すとバッチデータとして扱われ、外側にリストが1つ増えます。バッチデータでない時にはそれぞれの戻り値は以下のような内容になっています。

  • object detection
    • 長さがクラス数と等しいリストです
    • i番目の要素はクラスiの bbox の配列になっています
    • 配列の形状は (num_bboxes, 5) で、axis=1 の内訳は [x1, y1, x2, y2, confidence] です
  • instance segmentation
    • 長さが2のリストです
      • 0番目に物体検出の結果 (上記 object detection の内容と同じ) ものが入っています
      • 1番目にインスタンスセグメンテーションの結果が List[numpy.ndarray] で入っています
        • i番目の要素はクラスiのセグメンテーションマスクの配列になっています
        • 配列の形状は (img_height, img_width) です
          • img_width, img_height は入力に渡した画像のサイズと同じものです
        • 配列の中身は bool の値が入っていて、インスタンスの領域が True そうでない領域が False です
  • panoptic segmentation
    • キーが2つある辞書です。以下のキーを持っています
      • pan_results
        • panoptic segmentation の結果が numpy.ndarray で入っています
        • 配列の形状は (img_height, img_width) です
        • 配列の要素はセグメンテーションクラスを表す整数値です
      • ins_results
        • インスタンスセグメンテーションの結果 (上記 instance segmentation の内容と同じ) が入っています
from mmdet.apis import inference_detector


img = 'path/to/image'
result = inference_detector(model, img)

3. 推論結果を可視化する

推論した結果は model.show_result で可視化できます。この時、可視化は1枚ずつしかできないことに注意してください4show_result の引数は4つ5あります。

  • img: 画像のパスないし配列です
  • result: inference_detector の戻り値です
  • iou_thrs: 描画するオブジェクトの confidence の閾値です。デフォルト値は 0.3 です
  • out_file: 描画した結果を書き出すファイルのパスです。デフォルトでは None になっており、描画結果は保存されません。out_file を指定すると描画した配列オブジェクトが return されなくなります

show_result の戻り値は out_file を指定しない場合には bbox / instance mask / panoptic mask が描画された配列です。out_file を指定した場合には None になります。

# visualize result and get the array
vis = model.show_result(img, result, iou_thrs=0.3)
# plt.imshow(vis) / plt.savefig('path/to/save', vis)

# save visualized result directory
model.show_result(img, result, iou_thrs=0.3, out_file='path/to/save')

メモ

学習編・実装編もそのうち記事を書きます。

  1. 本当は4つありますが本記事では省略します。

  2. config のパスを渡すと内部的に mmcv.Config に変換されて利用されます。インスタンスを渡す意味はほぼないです。たぶん。

  3. その他、modelzoo / torchvision / open-mmlab / openmmlab / mmcls のアドレスも使えます。

  4. バッチで推論したデータを可視化するときはループを回して1枚ずつ処理する必要があります。

  5. 実際には bbox / mask の色や bbox の線の太さなどを調整する描画パラメータが他にもあります。詳しくは公式の docstring を参照してください。

8
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
6