MMDetectionでモデルを学習させる方法
たくさんモデルが使えて、State of art(現状最高)のモデルとかもMMDetectionで動いているで有名なMMDetectionですが、
バージョンアップですぐ動かなくなることでも僕の中で有名です。
この記事では、ちゃんとMMDetectionを動かして、自分のデータでモデルを学習させます。
Install
たのむよ
まあバージョンアップごとにいろんなエラーが出て、
うまくうごかないんだけど、
とりあえずこれでうごいたで。
2024/11/16の時点で。
動くセットアップ
インストール。
mimではなくpipでインストールすることと、mmcvのバージョンがポイント。
pip install mmcv==2.1.0
git clone https://github.com/open-mmlab/mmdetection.git
cd mmdetection
pip install -v -e .
インストールに30分以上かかるけどめげずにがんばろう。
mmdetection/hack_registry.pyというファイルを作る。
# hack_registry.py
import logging
from mmengine.registry import Registry
from mmengine.logging import print_log
from typing import Type, Optional, Union, List
def _register_module(self,
module: Type,
module_name: Optional[Union[str, List[str]]] = None,
force: bool = False) -> None:
"""Register a module.
Args:
module (type): Module to be registered. Typically a class or a
function, but generally all ``Callable`` are acceptable.
module_name (str or list of str, optional): The module name to be
registered. If not specified, the class name will be used.
Defaults to None.
force (bool): Whether to override an existing class with the same
name. Defaults to False.
"""
if not callable(module):
raise TypeError(f'module must be Callable, but got {type(module)}')
if module_name is None:
module_name = module.__name__
if isinstance(module_name, str):
module_name = [module_name]
for name in module_name:
if not force and name in self._module_dict:
existed_module = self.module_dict[name]
# raise KeyError(f'{name} is already registered in {self.name} '
# f'at {existed_module.__module__}')
print_log(
f'{name} is already registered in {self.name} '
f'at {existed_module.__module__}. Registration ignored.',
logger='current',
level=logging.INFO
)
self._module_dict[name] = module
Registry._register_module = _register_module
mmdetのimportの前にhack_registryをimport。
import hack_registry
from mmdet.apis import DetInferencer
これでやっと動く。
たのむよ、mmlabさん。
データセットの用意
COCO形式のデータセットを用意する。
私はYolo形式のデータセットからCOCO形式に変換したので、以下のスクリプトを用いました。
他の形式の方はがんばってください😉
データの移動には以下のコマンドが高速です。shutilだと大量のデータは重い。
source_dir = "/content/drive/MyDrive/datasets/defect_crop_car/labels/train"
destination_dir = "/content/defect_crop_car/labels/train"
!mkdir -p "{destination_dir}"
!find "{source_dir}" -type f | xargs -P 8 -I{{}} cp {{}} "{destination_dir}"
import json
import collections as cl
import cv2
def get_info():
tmp = cl.OrderedDict()
tmp["description"] = "my dataset"
tmp["url"] = ""
tmp["version"] = "1"
tmp["year"] = 2024
tmp["contributor"] = ""
tmp["data_created"] = "2024/11/17"
return tmp
def get_licenses():
tmp = cl.OrderedDict()
tmp["id"] = 0
tmp["url"] = ""
tmp["name"] = ""
return [tmp]
def get_image_data(id, image_path, image_file, w, h):
tmp = cl.OrderedDict()
tmp["license"] = 0
tmp["id"] = id
tmp["file_name"] = image_file
tmp["width"] = w
tmp["height"] = h
tmp["date_captured"] = ""
tmp["coco_url"] = ""
tmp["flickr_url"] = ""
return tmp
def get_annotation(id, image_id, category_id, bbox):
tmp = cl.OrderedDict()
tmp_segmentation = cl.OrderedDict()
tmp_segmentation = [bbox]
tmp["segmentation"] = tmp_segmentation
tmp["id"] = id
tmp["image_id"] = image_id
tmp["category_id"] = category_id
tmp["area"] = bbox[2]*bbox[3]
tmp["iscrowd"] = 0
tmp["bbox"] = bbox
return tmp
def get_categories():
tmps = []
sup = ["damage"]
classes = ["your_class0", "your_class1"]
for i in range(len(classes)):
tmp = cl.OrderedDict()
tmp["id"] = i
tmp["supercategory"] = sup
tmp["name"] = classes[i]
tmps.append(tmp)
return tmps
def create_dataset(image_dir, label_dir, json_path):
image_list = os.listdir(image_dir)
print(len(image_list))
info = get_info()
licenses = get_licenses()
categories = get_categories()
images = []
annotations = []
for i, image_file in enumerate(image_list):
# for i in range(100):
image_file = image_list[i]
image_path = os.path.join(image_dir,image_file)
label_file = os.path.splitext(os.path.basename(image_file))[0] + ".txt"
label_path = os.path.join(label_dir,label_file)
if os.path.exists(label_path):
img = cv2.imread(image_path)
img_h, img_w, _ = img.shape
image_data = get_image_data(i, image_path, image_file, img_w, img_h)
source_file = open(label_path)
for defect_index, line in enumerate(source_file):
staff = line.split()
class_idx = int(staff[0])
x_center, y_center, width, height = float(staff[1])*img_w, float(staff[2])*img_h, float(staff[3])*img_w, float(staff[4])*img_h
x = round(x_center-width/2,2)
y = round(y_center-height/2,2)
width = round(width,2)
height = round(height,2)
bbox = [x, y, width, height]
id = i * 1000 + defect_index
annotation = get_annotation(id, i, class_idx, bbox)
annotations.append(annotation)
images.append(image_data)
json_data = {
'info': info,
'images': images,
'licenses': licenses,
'annotations': annotations,
'categories': categories,
}
with open(json_path, 'w', encoding='utf-8') as f:
json.dump(json_data, f, ensure_ascii=False)
import os
dataset_root_dir = "my_dataset/"
image_dir = os.path.join(dataset_root_dir,"images/")
label_dir = os.path.join(dataset_root_dir,"labels/")
image_train_dir = os.path.join(image_dir, "train")
image_val_dir = os.path.join(image_dir, "val")
label_train_dir = os.path.join(label_dir, "train")
label_val_dir = os.path.join(label_dir, "val")
dataset_name = os.path.basename(os.path.normpath(dataset_root_dir))
train_json_path = os.path.join(dataset_root_dir, "train.json")
val_json_path = os.path.join(dataset_root_dir, "val.json")
create_dataset(image_train_dir, label_train_dir, train_json_path)
create_dataset(image_val_dir, label_val_dir, val_json_path)
これでdatasetのディレクトリにtrain.json, val.jsonができます。
configの作成と編集
_base_ = '/content/mmdetection/projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_16xb1_1x_coco.py'
classes = ("class0","class1")
data_root = 'my_dataset'
train_dataloader = dict(
batch_size=1,
dataset=dict(
ann_file='train.json',
data_prefix=dict(img='/images/train/')
),
)
val_dataloader = dict(
batch_size=1,
dataset=dict(
ann_file='my_dataset/val.json',
data_prefix=dict(img='images/val/')
),
)
test_dataloader = dict(
batch_size=1,
dataset=dict(
ann_file='my_dataset/val.json',
data_prefix=dict(img='images//val/')
),
)
test_evaluator = dict(
ann_file='my_dataset/val.json',
)
val_evaluator = dict(
ann_file='my_dataset/val.json',
)
load_from = 'co_dino_5scale_swin_large_1x_coco-27c13da4.pth'
mmdetection/mmdet/dataset/coco.py
のCocoDatasetのMETAINFOをカスタムデータセットのクラスに置き換える。
@DATASETS.register_module()
class CocoDataset(BaseDetDataset):
"""Dataset for COCO."""
METAINFO = {
'classes': ("class0", "class1"),
'palette': [(220, 20, 60), (119, 11, 32)]
}
使用するconfigの_base_ configを辿って、num_classesを書き換える。
僕の例では、
base = '/content/mmdetection/projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_16xb1_1x_coco.py'
のconfig ファイルのさらに_base_が
/content/mmdetection/projects/CO-DETR/configs/codino/co_dino_5scale_r50_lsj_8xb2_1x_coco.py
なので、それのnum_classesを自分のデータセットのクラス数にします。
num_classes = 2
これで学習を開始できます。
python tools/train.py my_config.py --work-dir /content/drive/MyDrive/codino
--work-dirに指定したパスに自動で結果保存dirが作成されます。
この場合、codinoは実行スクリプトでmakedirされます。
データ数が多い時は、小量のテストデータを作ってトレーニングが正しく1エポック実行されるか確認してから全データを学習すると良いと思います。全データで1epoch待ってエラー出たら解決に時間がかかる。
🐣
フリーランスエンジニアです。
お仕事のご相談こちらまで
rockyshikoku@gmail.com
機械学習関連の情報を発信しています。