1
1

MMdetectionのRTMdetで学習(カスタムデータ)

Last updated at Posted at 2024-05-20

環境

・Windows11
・RTX 3060Ti
・CUDA11.8
・Python3.8

・環境構築はこちら
https://qiita.com/kappanda/items/46e94507f5a01abf3b14

・MMdetectionでRTMdetのモデル使用

データセット

・カスタムデータセット(1クラス)
・データ形式はcoco
・フォルダ構造
  ├─coco
  │ ├─annotations
  │ │ ├─instances_train2017.json
  │ │ └─instances_val2017.json
  │ ├─train2017(train用画像のフォルダ)
  │ └─val2017(valid用画像のフォルダ)

事前学習済みモデルの取得

mmdetection/pretrained_weightsフォルダを作成し、
そこにRTMDet-sのmodelをダウンロード
https://github.com/open-mmlab/mmdetection/tree/main/configs/rtmdet

mmdetectionのコードを一部変更してカスタムCOCOデータに対応させる

こちらを参照

Config作成

mmdetection\configs\rtmdet\rtmdet_s_8xb32-300e_coco.pyを開いて
一番下に以下を追記

rtmdet_s_8xb32-300e_coco.py
load_from = "./pretrained_weights/rtmdet_s_8xb32-300e_coco_20220905_161602-387a891e.pth"

あとはbatchsizeとかnum_classesをいじりたいがrtmdet_s_8xb32-300e_coco.pyには記載するところなし

rtmdet_s_8xb32-300e_coco.py
_base_ = './rtmdet_l_8xb32-300e_coco.py'

rtmdet_l_8xb32-300e_coco.pyのファイルがベースになっているようなのでそちらを開いて
32行目 num_classes=1に変更

rtmdet_l_8xb32-300e_coco.py
    bbox_head=dict(
        type='RTMDetSepBNHead',
        num_classes=1,
        in_channels=256,
        stacked_convs=2,
        feat_channels=256,

113行目 batch_size=4
114行目 num_workers=4
119行目 batch_size=4, num_workers=4
に変更(8GB GPUではデフォルト値だとout of memoryのため)

rtmdet_l_8xb32-300e_coco.py
train_dataloader = dict(
    batch_size=4,
    num_workers=4,
    batch_sampler=None,
    pin_memory=True,
    dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=4, num_workers=4, dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

いざ学習!

python tools/train.py configs/rtmdet/rtmdet_s_8xb32-300e_coco.py

するとエラーが出現
RuntimeError: nms_impl: implementation for device cuda:0 not found.

~~miniconda3/envs/openmmlab/Lib/site-packages/mmcv/ops/nms.pyの27行目でトラブっている模様

nms.py
inds = ext_module.nms(
    bboxes, scores, iou_threshold=float(iou_threshold), offset=offset)

ですのでGithubを参考に(https://github.com/open-mmlab/mmdetection/issues/11437)
一部変更

nms.py
# inds = ext_module.nms(
#     bboxes, scores, iou_threshold=float(iou_threshold), offset=offset)

inds = ext_module.nms(
    bboxes.to('cpu'), scores.to('cpu'), iou_threshold=float(iou_threshold), offset=offset)

再度学習!

動いた!

05/20 14:56:37 - mmengine - INFO - Epoch(train)   [1][8/8]  base_lr: 2.8068e-05 lr: 2.8068e-05  eta: 0:30:42  time: 0.7703  data_time: 0.3194  memory: 2035  loss: 2.3908  loss_cls: 1.4148  loss_bbox: 0.9759
05/20 14:56:39 - mmengine - INFO - Exp name: rtmdet_s_8xb32-300e_coco_20240520_145622
05/20 14:56:39 - mmengine - INFO - Epoch(train)   [2][8/8]  base_lr: 6.0099e-05 lr: 6.0099e-05  eta: 0:20:06  time: 0.5061  data_time: 0.1816  memory: 2035  loss: 2.3865  loss_cls: 1.4282  loss_bbox: 0.9583
1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1