4
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

【MMDetection】 データ処理パイプラインのカスタマイズ編

Last updated at Posted at 2021-02-03

概要

まだまだ日本語情報が少ない物体検出フレームワーク"MMDetection"について、学んだことを記録していこうと思います。間違い等ありましたらぜひコメントで教えてください。よろしくお願いします。

前回:データセットのカスタマイズ編

今回は、データセットからデータを取得し、前処理を施してモデルに渡すまでのパイプラインをカスタマイズする方法をまとめます。

パイプラインについて

MMDetectionでは、MMCVのDataContainer型を利用してサイズ違いのデータを簡単にスタック、バッチ処理できるように工夫されています。

データセットとデータ処理のパイプラインは分離されていて、パイプラインではモデルに渡す辞書型のデータを準備するための全ステップを定義します。

以下にデータ処理パイプラインの例を示します。
(緑は新しく追加されるキー、橙は修正されるキーを示しています。)
data_pipeline.png
各操作は、data loading, pre-processing, formatting, test_time augmentationに分類されます。

パイプラインの設計

1. configの記述

パイプラインはconfigファイルに記述します。
configファイルに、train_pipline, test_pipelineをキーとした配列を定義し、所望の処理をtype名とした辞書を順に加えていきます。

以下にFaster R-CNNのパイプラインの例を示します。

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]

2. パイプラインの種類

公式ドキュメントに載っていたものを挙げてみます。
コードを見ると他にも用意されているみたいです。詳細はこちらを確認してください。

2.1. Data loading

  • LoadImageFromFile
  • LoadAnnotations
  • LoadProposals

2.2. Pre-processing

  • Resize
  • RandomFlip
  • Pad
  • RandomCrop
  • Normalize
  • SegRescale
  • PhotoMetricDistortion
  • Expand
  • MinIoURandomCrop
  • Corrupt

2.3. Formatting

  • ToTensor
  • ImageToTensor
  • Transpose
  • ToDataContainer
  • DefaultFormatBundle
  • Collect

2.4. Test time augmentation

  • MultiScaleFlipAug

3. パイプライン処理の拡張

パイプラインに含める各処理を自作することも可能です。

3.1. モジュールの登録

辞書を受け取り辞書を返すクラスを定義し、 ファイルをmmdet/datasets/pipelinesに配置します。

my_pipeline.py
from mmdet.datasets import PIPLINES

@PIPELINES.register_module()  # パイプラインモジュールを登録するおまじない
class MyTransform:

    def __call__(self, results):
        # 任意の処理
        results['dummy'] = True
        return results

3.2. __init__.py の修正

自分で新しく定義したクラスがimportされるように、mmdet/datasets/pipelines/__init__.pyを修正します。

__init__.py
# 追加
from .my_pipeline import MyDataset

# 修正
__all__ = [
    'Cumpose', ... , 'MyTransform'  # 'MyTransform'を追加
]

3.3. configファイルの記述

新しく登録したモジュールをconfigファイルで利用します。

    train_pipeline = [
        ...,
        dict(type='MyTransform'),  # type名 = 自作したクラス名
        ...
    ]

参考文献

4
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?