はじめに
機械学習モデルを構築するところから学習する目的で、農業をテーマとする機械学習の教師あり分類で使えそうなデータセットを探していたところ、Laboro Tomatoというトマト画像物体検出データセットを見つけました。
しかし、MMDetectionのツールボックス上に学習済みモデルが構築済みで、その目的を達成することができないと分かりました。
そこで方向性を変えて、学習済みモデルを利用して画像検出結果を可視化するところまで行うことにしました。
環境構築
システム要件
Google Colaboratyに環境を構築します。
Colab上に追加で、mmdetとmmcvのインストールが必要です。
MMDetectionのインストール
Colabの「ランタイム > ランタイムのタイプを変更」でGPUを選択しておきます。
GitHubのリポジトリから、mmdetectionをクローンしてインストールします。
!git clone https://github.com/open-mmlab/mmdetection.git
%cd /content/mmdetection
!pip install -v -e .
MMCV(MultiMedia Computing and Vision)ライブラリをインストールします。
mmcv-fullではなくmmcvをインストールします。mmcv-fullをインストールすると、mmdetのimportに失敗しました。(2024年1月時点)mmcv-fullでは、mmcvのバージョンが1.7.2となり、mmdetが要求する2.0.0以上を満たさないためです。
!pip install mmcv
データセットのダウンロード
インストールしたmmdetctionのフォルダ構成は以下のようになっています。
mmdetectionフォルダ直下にdataフォルダを作成して、データセットをダウンロードします。
mmdetection
├── mmdet
├── tools
├── configs
├── data
│ ├── laboro_tomato
│ │ ├── annotations
│ │ ├── train
│ │ ├── test
まず、データセットを保存するフォルダを作成します。
作成したdataフォルダに、Laboro TomatoのGithubからデータセットをダウンロードして、ダウンロードしたzipファイルを解凍します。
import urllib.request
import zipfile
import os
# dataフォルダを作成する
DATA_PATH = '/content/mmdetection/data'
os.mkdir(DATA_PATH)
os.chdir(DATA_PATH)
# URLを指定
url = "http://assets.laboro.ai.s3.amazonaws.com/laborotomato/laboro_tomato.zip"
save_name = url.split('/')[-1]
# ダウンロードする
mem = urllib.request.urlopen(url).read()
# ファイルへ保存
with open(save_name, mode='wb') as f:
f.write(mem)
# zipファイルをカレントディレクトリに展開する
zfile = zipfile.ZipFile(save_name)
zfile.extractall('.')
学習ずみモデルのダウンロード
mmdetectionフォルダ直下にpre_trainedmodelフォルダを作成して、作成したフォルダに学習ずみモデルをダウンロードします。
mmdetection
├── mmdet
├── tools
├── configs
├── data
├── pre_trained_model
│ ├── laboro_tomato_48ep.pth
PTH_PATH = "/content/mmdetection/pre_trained_model"
!mkdir "$PTH_PATH"
!wget -P "$PTH_PATH" http://assets.laboro.ai.s3.amazonaws.com/laborotomato/laboro_tomato_48ep.pth
設定ファイルのダウンロード
Loboro Tomato向けにカスタマイズする設定ファイルを以下のように配置します。
config_example (mmdetection)
├── configs
│ ├── _base_
│ │ ├── datasets
│ │ │ └── laboro_tomato_coco_instance.py
│ │ └── models
│ │ └── laboro_tomato_mask-rcnn_r50_fpn.py
│ └── mask_rcnn
│ └── laboro_tomato_mask-rcnn_r50_fpn_1x_coco.py
└── mmdet
└── datasets
├── __init__.py
└── laboro_tomato.py
GitHubのレポジトリから、Loboro Tomato向けの設定ファイルをダウンロードします。
%cd /content
!git clone https://github.com/laboroai/LaboroTomato.git
%cp -r /content/LaboroTomato/config_example/* /content/mmdetection/
%rm -rf /content/LaboroTomato
設定ファイルの内容については、補足説明の折りたたみを開いてください。
補足説明
laboro_tomato.py の内容
データセットcoco設定ファイルの内容でlaboro_tomato.pyを作成して、クラス名をLaboroTomatoに変更し、METAINFOパラメータを以下の例のように変更したものです。
@DATASETS.register_module()
class LaboroTomato(BaseDetDataset):
"""Dataset for COCO."""
METAINFO = {
'classes':
('b_fully_ripened', 'b_half_ripened', 'b_green', 'l_fully_ripened', 'l_half_ripened', 'l_green'),
# palette is a list of color tuples, which is used for visualization.
'palette':
[(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228), (0, 60, 100)]
}
__init__.py の内容
既存のmmdet/datasets/__init__.py にデータセット名を追加しています。
from .laboro_tomato import LaboroTomato
__all__ = [
..., 'LaboroTomato'
]
laboro_tomato_coco_instance.py の内容
cocoインスタンス設定ファイルの内容でconfigs/_base_/datasets/にlaboro_tomato_coco_instance.pyを作成し、dataset_type、data_root、dataloaderのパスを以下の例のように変更しています。
dataset_type = 'LaboroTomato'
data_root = 'data/laboro_tomato/'
...
train_dataloader = dict(
...
dataset=dict(
...
ann_file='annotations/train.json',
data_prefix=dict(img='train/'),
...
))
val_dataloader = dict(
...
dataset=dict(
...
ann_file='annotations/test.json',
data_prefix=dict(img='test/'),
...
))
test_dataloader = val_dataloader
val_evaluator = dict(
...
ann_file=data_root + 'annotations/test.json',
...
)
test_evaluator = val_evaluator
laboro_tomato_mask-rcnn_r50_fpn.py の内容
設定ファイルmask-rcnn_r50_fpn.pyの内容でconfigs/_base_/models/にlaboro_tomato_mask-rcnn_r50_fpn.pyを作成し、num_classesの設定を変更しています。
...
model = dict(
...
roi_head=dict(
...
bbox_head=dict(
...
num_classes=6,
...
),
...
mask_head=dict(
...
num_classes=6,
...
)
...
)
...
)
laboro_tomato_mask_rcnn_r50_fpn_1x_coco.py の内容
configs/mask_rcnn/にmask_rcnn_r50_fpn_1x_coco.pyの内容でlaboro_tomato_mask_rcnn_r50_fpn_1x_coco.pyを作成し、パスの設定を変更しています。
_base_ = [
'../_base_/models/laboro_tomato_mask-rcnn_r50_fpn.py',
'../_base_/datasets/laboro_tomato_coco_instance.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
物体検出を行う
init_detectorメソッドの引数に、ダウンロードしたLaboro TomatoのConfigファイルと学習ずみモデルを設定して、画像検出器を初期化します。
inference_detectorメソッドの引数に、ダウンロードしたトマト画像データセットから任意の画像を指定して、物体検出を実行します。
from mmdet.apis import init_detector, inference_detector
config_file = '/content/mmdetection/configs/mask_rcnn/laboro_tomato_mask-rcnn_r50_fpn_1x_coco.py'
checkpoint_file = '/content/mmdetection/pre_trained_model/laboro_tomato_48ep.pth'
img_path = '/content/mmdetection/data/laboro_tomato/test/IMG_0991.jpg'
model = init_detector(config_file, checkpoint_file, device='cuda:0')
result = inference_detector(model, img_path)
print(result)
resultには、検出結果がDetDataSampleで返されます。
<DetDataSample(
META INFORMATION
batch_input_shape: (1088, 800)
img_shape: (1067, 800)
img_path: '/content/mmdetection/data/laboro_tomato/test/IMG_0991.jpg'
pad_shape: (1088, 800)
ori_shape: (4032, 3024)
scale_factor: (0.26455026455026454, 0.2646329365079365)
img_id: 0
DATA FIELDS
ignored_instances: <InstanceData(
META INFORMATION
DATA FIELDS
masks: BitmapMasks(num_masks=0, height=4032, width=3024)
bboxes: tensor([], size=(0, 4))
labels: tensor([], dtype=torch.int64)
) at 0x78d42ab6dcc0>
pred_instances: <InstanceData(
META INFORMATION
DATA FIELDS
masks: tensor([[[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]],
[[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]],
[[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]],
...,
[[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]],
[[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]],
[[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]]])
scores: tensor([0.9985, 0.9976, 0.9961, 0.9929, 0.9903, 0.9878, 0.4669, 0.3010, 0.1244,
0.0653])
bboxes: tensor([[2502.9565, 2859.0486, 2928.9744, 3426.6133],
[1159.5984, 1825.1897, 2293.2417, 2866.6699],
[ 875.4203, 715.0433, 1974.3251, 1948.3950],
[ 921.3087, 1794.6794, 1405.1306, 2443.5981],
[2097.5969, 274.3954, 2651.9568, 867.3937],
[2234.8176, 858.9927, 2660.0400, 1254.4799],
[2235.8906, 854.2073, 2655.7715, 1255.8162],
[ 890.0048, 718.0613, 1964.9799, 1935.0585],
[2100.9795, 258.7371, 2658.7690, 885.2468],
[ 774.6519, 719.2980, 2040.8054, 2042.9907]])
labels: tensor([1, 0, 0, 0, 2, 2, 5, 1, 5, 3])
) at 0x78d54ef062c0>
gt_instances: <InstanceData(
META INFORMATION
DATA FIELDS
masks: BitmapMasks(num_masks=0, height=4032, width=3024)
bboxes: tensor([], size=(0, 4))
labels: tensor([], dtype=torch.int64)
) at 0x78d42ab6e230>
) at 0x78d42ab6dd20>
検出結果の可視化
from mmdet.registry import VISUALIZERS
import mmcv
# Visualizerを構築する
visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = model.dataset_meta
img = mmcv.imconvert(mmcv.imread(img_path), 'bgr', 'rgb')
# 予測結果を表示する
visualizer.add_datasample(
'result',
img,
data_sample=result,
draw_gt=False,
wait_time=0,
out_file=None,
pred_score_thr=0.7
)
visualizer.show()
物体が検出されて判定もされていますが、なぜだか文字が小さくて読めない・・・
関連記事