2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

【MMDetection】 モジュール作成編

Posted at

概要

まだまだ日本語情報が少ない物体検出フレームワーク"MMDetection"について、学んだことを記録していこうと思います。間違い等ありましたらぜひコメントで教えてください。よろしくお願いします。

前回:データ処理パイプラインのカスタマイズ編

今回は、モデルのコンポーネントとなるモジュールを自分で作ってモデルに組み込む方法をまとめます。

バックボーンの作成

MobileNetを作成してみます。

1. 新しいモジュールの定義

mmdet/models/backbones/mobilenet.pyを作成し、MobileNetを実装します。
さらに、デコレータを用いることでバックボーンモジュールとして登録することができます。

mobilenet.py
import torch.nn as nn

from ..builder import BACKBONES


@BACKBONES.register_module()  # バックボーンモジュールとして登録するためのデコレータ
class MobileNet(nn.Module):

    def __init__(self, arg1, arg2):
        # 省略
        pass

    def forward(self, x):  # should return a tuple
        # 省略
        pass

    def init_weights(self, pretrained=None):
        # 省略
        pass

2. 作成したモジュールのimport

mmdet/models/backbones/__init__.pyに以下の内容を追加する

__init__.py
from .mobilenet import MobileNet

3. モジュールを利用する

config.py
model = dict(
    ...
    backbone=dict(
        type='MobileNet',
        arg1=xxx,
        arg2=xxx),
    ...

ネックの作成

PAFPNを作成してみます。

1. 新しいモジュールの定義

mmdet/models/backbones/pafpn.pyを作成し、PAFPNを実装します。
さらに、デコレータを用いることでネックモジュールとして登録することができます。

pafpn.py
from ..builder import NECKS

@NECKS.register_module()  # ネックモジュールとして登録するためのデコレータ
class PAFPN(nn.Module):

    def __init__(self,
                in_channels,
                out_channels,
                num_outs,
                start_level=0,
                end_level=-1,
                add_extra_convs=False):
        pass

    def forward(self, inputs):
        # 省略
        pass

2. 作成したモジュールのimport

mmdet/models/necks/__init__.pyに以下の内容を追加する

__init__.py
from .pafpn import PAFPN

3. モジュールを利用する

config.py
neck=dict(
    type='PAFPN',
    in_channels=[256, 512, 1024, 2048],
    out_channels=256,
    num_outs=5)

ヘッドの作成

Double Head R-CNN を実装するために、DoubleConvFCBBoxHeadとDoubleHeadRoIHeadを作成します。

1. 新しいモジュールの定義

mmdet/models/roi_heads/bbox_heads/double_bbox_head.pyを作成し、 DoubleConvFCBBoxHeadを実装します。
デコレータを用いることでヘッドモジュールとして登録することができます。

double_bbox_head.py
from mmdet.models.builder import HEADS
from .bbox_head import BBoxHead

@HEADS.register_module()
class DoubleConvFCBBoxHead(BBoxHead):

    def __init__(self,
                 num_convs=0,
                 num_fcs=0,
                 conv_out_channels=1024,
                 fc_out_channels=1024,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 **kwargs):
        kwargs.setdefault('with_avg_pool', True)
        super(DoubleConvFCBBoxHead, self).__init__(**kwargs)

    def init_weights(self):
        # conv layers are already initialized by ConvModule

    def forward(self, x_cls, x_reg):
        # 省略

mmdet/models/roi_heads/double_roi_head.pyを作成し、 DoubleHeadRoIHeadを実装します。
デコレータを用いることでヘッドモジュールとして登録することができます。

double_roi_head.py
from ..builder import HEADS
from .standard_roi_head import StandardRoIHead


@HEADS.register_module()
class DoubleHeadRoIHead(StandardRoIHead):

    def __init__(self, reg_roi_scale_factor, **kwargs):
        super(DoubleHeadRoIHead, self).__init__(**kwargs)
        self.reg_roi_scale_factor = reg_roi_scale_factor

    def _bbox_forward(self, x, rois):
        # 省略
        pass

2. 作成したモジュールのimport

mmdet/models/bbox_heads/__init__.pymmdet/models/roi_heads/__init__.pyを修正してモジュールをimportする。

3. モジュールを利用する

config.py
_base_ = '../faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
model = dict(
    roi_head=dict(
        type='DoubleHeadRoIHead',
        reg_roi_scale_factor=1.3,
        bbox_head=dict(
            _delete_=True,
            type='DoubleConvFCBBoxHead',
            num_convs=4,
            num_fcs=2,
            in_channels=256,
            conv_out_channels=1024,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=2.0),
            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=2.0))))

Lossの作成

BBox回帰のためのMyLossを作成してみます。

1. 新しいLossの定義

mmdet/models/losses/my_loss.pyを作成し、MyLossを実装します。
weighted_lossデコレータを用いると、重み付き損失関数を簡単に作ることができます。

my_loss.py
import torch
import torch.nn as nn

from ..builder import LOSSES
from .utils import weighted_loss

@weighted_loss
def my_loss(pred, target):
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss

@LOSSES.register_module()
class MyLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(MyLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_bbox = self.loss_weight * my_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_bbox

2. 作成したLossのimport

mmdet/models/losses/__init__.pyに以下の内容を追加する

__init__.py
from .my_loss import MyLoss, my_loss

3. Lossを利用する

config.py
loss_bbox=dict(type='MyLoss', loss_weight=1.0))

参考文献

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?