0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

MMDetectionでモデルを学習させる方法

たくさんモデルが使えて、State of art(現状最高)のモデルとかもMMDetectionで動いているで有名なMMDetectionですが、
バージョンアップですぐ動かなくなることでも僕の中で有名です。

この記事では、ちゃんとMMDetectionを動かして、自分のデータでモデルを学習させます。

Install

たのむよ

まあバージョンアップごとにいろんなエラーが出て、
うまくうごかないんだけど、
とりあえずこれでうごいたで。
2024/11/16の時点で。

動くセットアップ

インストール。
mimではなくpipでインストールすることと、mmcvのバージョンがポイント。

pip install mmcv==2.1.0
git clone https://github.com/open-mmlab/mmdetection.git
cd mmdetection
pip install -v -e .

インストールに30分以上かかるけどめげずにがんばろう。

mmdetection/hack_registry.pyというファイルを作る。

hack_registry.py
#  hack_registry.py
import logging

from mmengine.registry import Registry
from mmengine.logging import print_log
from typing import Type, Optional, Union, List


def _register_module(self,
                     module: Type,
                     module_name: Optional[Union[str, List[str]]] = None,
                     force: bool = False) -> None:
    """Register a module.

    Args:
        module (type): Module to be registered. Typically a class or a
            function, but generally all ``Callable`` are acceptable.
        module_name (str or list of str, optional): The module name to be
            registered. If not specified, the class name will be used.
            Defaults to None.
        force (bool): Whether to override an existing class with the same
            name. Defaults to False.
    """
    if not callable(module):
        raise TypeError(f'module must be Callable, but got {type(module)}')

    if module_name is None:
        module_name = module.__name__
    if isinstance(module_name, str):
        module_name = [module_name]
    for name in module_name:
        if not force and name in self._module_dict:
            existed_module = self.module_dict[name]
            # raise KeyError(f'{name} is already registered in {self.name} '
            #                f'at {existed_module.__module__}')
            print_log(
                f'{name} is already registered in {self.name} '
                f'at {existed_module.__module__}. Registration ignored.',
                logger='current',
                level=logging.INFO
            )
        self._module_dict[name] = module


Registry._register_module = _register_module

mmdetのimportの前にhack_registryをimport。

import hack_registry

from mmdet.apis import DetInferencer

これでやっと動く。
たのむよ、mmlabさん。

データセットの用意

COCO形式のデータセットを用意する。
私はYolo形式のデータセットからCOCO形式に変換したので、以下のスクリプトを用いました。
他の形式の方はがんばってください😉

データの移動には以下のコマンドが高速です。shutilだと大量のデータは重い。

source_dir = "/content/drive/MyDrive/datasets/defect_crop_car/labels/train"
destination_dir = "/content/defect_crop_car/labels/train"  

!mkdir -p "{destination_dir}"

!find "{source_dir}" -type f | xargs -P 8 -I{{}} cp {{}} "{destination_dir}"
import json
import collections as cl
import cv2

def get_info():
  tmp = cl.OrderedDict()
  tmp["description"] = "my dataset"
  tmp["url"] = ""
  tmp["version"] = "1"
  tmp["year"] = 2024
  tmp["contributor"] = ""
  tmp["data_created"] = "2024/11/17"
  return tmp

def get_licenses():
  tmp = cl.OrderedDict()
  tmp["id"] = 0
  tmp["url"] = ""
  tmp["name"] = ""
  return [tmp]

def get_image_data(id, image_path, image_file, w, h):
  tmp = cl.OrderedDict()
  tmp["license"] = 0
  tmp["id"] = id
  tmp["file_name"] = image_file
  tmp["width"] = w
  tmp["height"] = h
  tmp["date_captured"] = ""
  tmp["coco_url"] = ""
  tmp["flickr_url"] = ""
  return tmp

def get_annotation(id, image_id, category_id, bbox):
  tmp = cl.OrderedDict()

  tmp_segmentation = cl.OrderedDict()
  tmp_segmentation = [bbox]
  tmp["segmentation"] = tmp_segmentation
  tmp["id"] = id
  tmp["image_id"] = image_id
  tmp["category_id"] = category_id
  tmp["area"] = bbox[2]*bbox[3]
  tmp["iscrowd"] = 0
  tmp["bbox"] =  bbox
  return tmp

def get_categories():
  tmps = []
  sup = ["damage"]
  classes = ["your_class0", "your_class1"]
  for i in range(len(classes)):
    tmp = cl.OrderedDict()
    tmp["id"] = i
    tmp["supercategory"] = sup
    tmp["name"] = classes[i]
    tmps.append(tmp)
  return tmps

def create_dataset(image_dir, label_dir, json_path):
    image_list = os.listdir(image_dir)
    print(len(image_list))
    info = get_info()
    licenses = get_licenses()
    categories = get_categories()
    images = []
    annotations = []
    
    for i, image_file in enumerate(image_list):
    # for i in range(100):
      image_file = image_list[i]
      image_path = os.path.join(image_dir,image_file)
      label_file = os.path.splitext(os.path.basename(image_file))[0] + ".txt"
      label_path = os.path.join(label_dir,label_file)
      if os.path.exists(label_path):
        img = cv2.imread(image_path)
        img_h, img_w, _ = img.shape
    
        image_data = get_image_data(i, image_path, image_file, img_w, img_h)
        source_file = open(label_path)
        for defect_index, line in enumerate(source_file):
          staff = line.split()
          class_idx = int(staff[0])
    
          x_center, y_center, width, height = float(staff[1])*img_w, float(staff[2])*img_h, float(staff[3])*img_w, float(staff[4])*img_h
          x = round(x_center-width/2,2)
          y = round(y_center-height/2,2)
          width = round(width,2)
          height = round(height,2)
          bbox = [x, y, width, height]
          id = i * 1000 + defect_index
          annotation = get_annotation(id, i, class_idx, bbox)
          annotations.append(annotation)
          images.append(image_data)
    
    json_data = {
        'info': info,
        'images': images,
        'licenses': licenses,
        'annotations': annotations,
        'categories': categories,
    }
    
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(json_data, f, ensure_ascii=False)
import os

dataset_root_dir = "my_dataset/"
image_dir = os.path.join(dataset_root_dir,"images/")
label_dir = os.path.join(dataset_root_dir,"labels/")

image_train_dir = os.path.join(image_dir, "train")
image_val_dir = os.path.join(image_dir, "val")
label_train_dir = os.path.join(label_dir, "train")
label_val_dir = os.path.join(label_dir, "val")

dataset_name = os.path.basename(os.path.normpath(dataset_root_dir))
train_json_path = os.path.join(dataset_root_dir, "train.json")
val_json_path = os.path.join(dataset_root_dir, "val.json")

create_dataset(image_train_dir, label_train_dir, train_json_path)
create_dataset(image_val_dir, label_val_dir, val_json_path)

これでdatasetのディレクトリにtrain.json, val.jsonができます。

configの作成と編集

my_config.py
_base_ = '/content/mmdetection/projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_16xb1_1x_coco.py' 

classes = ("class0","class1")
data_root = 'my_dataset'

train_dataloader = dict(
    batch_size=1,
    dataset=dict(
        ann_file='train.json',
        data_prefix=dict(img='/images/train/')
    ),
)
val_dataloader = dict(
    batch_size=1,
    dataset=dict(
        ann_file='my_dataset/val.json',
        data_prefix=dict(img='images/val/')
    ),
)

test_dataloader = dict(
    batch_size=1,
    dataset=dict(
        ann_file='my_dataset/val.json',
        data_prefix=dict(img='images//val/')
    ),
)

test_evaluator = dict(
    ann_file='my_dataset/val.json',
)

val_evaluator = dict(
    ann_file='my_dataset/val.json',
)

load_from = 'co_dino_5scale_swin_large_1x_coco-27c13da4.pth'

mmdetection/mmdet/dataset/coco.py
のCocoDatasetのMETAINFOをカスタムデータセットのクラスに置き換える。

coco.py
@DATASETS.register_module()
class CocoDataset(BaseDetDataset):
    """Dataset for COCO."""

    METAINFO = {
      'classes': ("class0", "class1"),
      'palette': [(220, 20, 60), (119, 11, 32)]
    }

使用するconfigの_base_ configを辿って、num_classesを書き換える。
僕の例では、

base = '/content/mmdetection/projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_16xb1_1x_coco.py'
のconfig ファイルのさらに_base_が
/content/mmdetection/projects/CO-DETR/configs/codino/co_dino_5scale_r50_lsj_8xb2_1x_coco.py
なので、それのnum_classesを自分のデータセットのクラス数にします。

co_dino_5scale_r50_lsj_8xb2_1x_coco.py
num_classes = 2

これで学習を開始できます。

python tools/train.py my_config.py --work-dir /content/drive/MyDrive/codino

--work-dirに指定したパスに自動で結果保存dirが作成されます。
この場合、codinoは実行スクリプトでmakedirされます。

データ数が多い時は、小量のテストデータを作ってトレーニングが正しく1エポック実行されるか確認してから全データを学習すると良いと思います。全データで1epoch待ってエラー出たら解決に時間がかかる。

🐣


フリーランスエンジニアです。
お仕事のご相談こちらまで
rockyshikoku@gmail.com

機械学習関連の情報を発信しています。

Twitter
Medium
GitHub

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?