1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

LaboroTomato画像物体検出

Posted at

はじめに

機械学習モデルを構築するところから学習する目的で、農業をテーマとする機械学習の教師あり分類で使えそうなデータセットを探していたところ、Laboro Tomatoというトマト画像物体検出データセットを見つけました。
しかし、MMDetectionのツールボックス上に学習済みモデルが構築済みで、その目的を達成することができないと分かりました。
そこで方向性を変えて、学習済みモデルを利用して画像検出結果を可視化するところまで行うことにしました。

環境構築

システム要件

Google Colaboratyに環境を構築します。
Colab上に追加で、mmdetとmmcvのインストールが必要です。

MMDetectionのインストール

Colabの「ランタイム > ランタイムのタイプを変更」でGPUを選択しておきます。

GitHubのリポジトリから、mmdetectionをクローンしてインストールします。

!git clone https://github.com/open-mmlab/mmdetection.git
%cd /content/mmdetection
!pip install -v -e .

MMCV(MultiMedia Computing and Vision)ライブラリをインストールします。

mmcv-fullではなくmmcvをインストールします。mmcv-fullをインストールすると、mmdetのimportに失敗しました。(2024年1月時点)mmcv-fullでは、mmcvのバージョンが1.7.2となり、mmdetが要求する2.0.0以上を満たさないためです。

!pip install mmcv

データセットのダウンロード

インストールしたmmdetctionのフォルダ構成は以下のようになっています。
mmdetectionフォルダ直下にdataフォルダを作成して、データセットをダウンロードします。

mmdetection
├── mmdet
├── tools
├── configs
├── data
│   ├── laboro_tomato
│   │   ├── annotations
│   │   ├── train
│   │   ├── test

まず、データセットを保存するフォルダを作成します。
作成したdataフォルダに、Laboro TomatoのGithubからデータセットをダウンロードして、ダウンロードしたzipファイルを解凍します。

import urllib.request
import zipfile
import os

# dataフォルダを作成する
DATA_PATH = '/content/mmdetection/data'
os.mkdir(DATA_PATH)
os.chdir(DATA_PATH)

# URLを指定
url = "http://assets.laboro.ai.s3.amazonaws.com/laborotomato/laboro_tomato.zip"
save_name = url.split('/')[-1]

# ダウンロードする
mem = urllib.request.urlopen(url).read()

# ファイルへ保存
with open(save_name, mode='wb') as f:
    f.write(mem)

# zipファイルをカレントディレクトリに展開する
zfile = zipfile.ZipFile(save_name)
zfile.extractall('.')

学習ずみモデルのダウンロード

mmdetectionフォルダ直下にpre_trainedmodelフォルダを作成して、作成したフォルダに学習ずみモデルをダウンロードします。

mmdetection
├── mmdet
├── tools
├── configs
├── data
├── pre_trained_model
│   ├── laboro_tomato_48ep.pth 
PTH_PATH = "/content/mmdetection/pre_trained_model"
!mkdir "$PTH_PATH"
!wget -P "$PTH_PATH" http://assets.laboro.ai.s3.amazonaws.com/laborotomato/laboro_tomato_48ep.pth

設定ファイルのダウンロード

Loboro Tomato向けにカスタマイズする設定ファイルを以下のように配置します。

config_example (mmdetection)
├── configs
│   ├── _base_
│   │   ├── datasets
│   │   │   └── laboro_tomato_coco_instance.py
│   │   └── models
│   │       └── laboro_tomato_mask-rcnn_r50_fpn.py
│   └── mask_rcnn
│       └── laboro_tomato_mask-rcnn_r50_fpn_1x_coco.py
└── mmdet
    └── datasets
        ├── __init__.py
        └── laboro_tomato.py

GitHubのレポジトリから、Loboro Tomato向けの設定ファイルをダウンロードします。

%cd /content
!git clone https://github.com/laboroai/LaboroTomato.git
%cp -r /content/LaboroTomato/config_example/* /content/mmdetection/
%rm -rf /content/LaboroTomato

設定ファイルの内容については、補足説明の折りたたみを開いてください。

補足説明

laboro_tomato.py の内容

データセットcoco設定ファイルの内容でlaboro_tomato.pyを作成して、クラス名をLaboroTomatoに変更し、METAINFOパラメータを以下の例のように変更したものです。


@DATASETS.register_module()
class LaboroTomato(BaseDetDataset):
    """Dataset for COCO."""

    METAINFO = {
        'classes':
        ('b_fully_ripened', 'b_half_ripened', 'b_green', 'l_fully_ripened', 'l_half_ripened', 'l_green'),
        # palette is a list of color tuples, which is used for visualization.
        'palette':
        [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228), (0, 60, 100)]
    }

__init__.py の内容

既存のmmdet/datasets/__init__.py にデータセット名を追加しています。

from .laboro_tomato import LaboroTomato

__all__ = [    
           ..., 'LaboroTomato'
          ]

laboro_tomato_coco_instance.py の内容

cocoインスタンス設定ファイルの内容でconfigs/_base_/datasets/にlaboro_tomato_coco_instance.pyを作成し、dataset_type、data_root、dataloaderのパスを以下の例のように変更しています。

dataset_type = 'LaboroTomato'
data_root = 'data/laboro_tomato/'
...
train_dataloader = dict(
    ...
    dataset=dict(
        ...
        ann_file='annotations/train.json',
        data_prefix=dict(img='train/'),
        ...
    ))
val_dataloader = dict(
    ...
    dataset=dict(
        ...
        ann_file='annotations/test.json',
        data_prefix=dict(img='test/'),
        ...
    ))
test_dataloader = val_dataloader

val_evaluator = dict(
    ...
    ann_file=data_root + 'annotations/test.json',
    ...
    )
test_evaluator = val_evaluator

laboro_tomato_mask-rcnn_r50_fpn.py の内容

設定ファイルmask-rcnn_r50_fpn.pyの内容でconfigs/_base_/models/にlaboro_tomato_mask-rcnn_r50_fpn.pyを作成し、num_classesの設定を変更しています。

...
model = dict(
    ...
    roi_head=dict(
        ...
        bbox_head=dict(
            ...
            num_classes=6,
            ...
        ),
        ...
        mask_head=dict(
            ...
            num_classes=6,
            ...
        )
        ...
    )
    ...
)

laboro_tomato_mask_rcnn_r50_fpn_1x_coco.py の内容

configs/mask_rcnn/にmask_rcnn_r50_fpn_1x_coco.pyの内容でlaboro_tomato_mask_rcnn_r50_fpn_1x_coco.pyを作成し、パスの設定を変更しています。

_base_ = [
    '../_base_/models/laboro_tomato_mask-rcnn_r50_fpn.py',
    '../_base_/datasets/laboro_tomato_coco_instance.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]

物体検出を行う

init_detectorメソッドの引数に、ダウンロードしたLaboro TomatoのConfigファイルと学習ずみモデルを設定して、画像検出器を初期化します。
inference_detectorメソッドの引数に、ダウンロードしたトマト画像データセットから任意の画像を指定して、物体検出を実行します。

from mmdet.apis import init_detector, inference_detector

config_file = '/content/mmdetection/configs/mask_rcnn/laboro_tomato_mask-rcnn_r50_fpn_1x_coco.py'
checkpoint_file = '/content/mmdetection/pre_trained_model/laboro_tomato_48ep.pth'
img_path = '/content/mmdetection/data/laboro_tomato/test/IMG_0991.jpg'

model = init_detector(config_file, checkpoint_file, device='cuda:0')
result = inference_detector(model, img_path)
print(result)

resultには、検出結果がDetDataSampleで返されます。

<DetDataSample(

    META INFORMATION
    batch_input_shape: (1088, 800)
    img_shape: (1067, 800)
    img_path: '/content/mmdetection/data/laboro_tomato/test/IMG_0991.jpg'
    pad_shape: (1088, 800)
    ori_shape: (4032, 3024)
    scale_factor: (0.26455026455026454, 0.2646329365079365)
    img_id: 0

    DATA FIELDS
    ignored_instances: <InstanceData(
        
            META INFORMATION
        
            DATA FIELDS
            masks: BitmapMasks(num_masks=0, height=4032, width=3024)
            bboxes: tensor([], size=(0, 4))
            labels: tensor([], dtype=torch.int64)
        ) at 0x78d42ab6dcc0>
    pred_instances: <InstanceData(
        
            META INFORMATION
        
            DATA FIELDS
            masks: tensor([[[False, False, False,  ..., False, False, False],
                         [False, False, False,  ..., False, False, False],
                         [False, False, False,  ..., False, False, False],
                         ...,
                         [False, False, False,  ..., False, False, False],
                         [False, False, False,  ..., False, False, False],
                         [False, False, False,  ..., False, False, False]],
                
                        [[False, False, False,  ..., False, False, False],
                         [False, False, False,  ..., False, False, False],
                         [False, False, False,  ..., False, False, False],
                         ...,
                         [False, False, False,  ..., False, False, False],
                         [False, False, False,  ..., False, False, False],
                         [False, False, False,  ..., False, False, False]],
                
                        [[False, False, False,  ..., False, False, False],
                         [False, False, False,  ..., False, False, False],
                         [False, False, False,  ..., False, False, False],
                         ...,
                         [False, False, False,  ..., False, False, False],
                         [False, False, False,  ..., False, False, False],
                         [False, False, False,  ..., False, False, False]],
                
                        ...,
                
                        [[False, False, False,  ..., False, False, False],
                         [False, False, False,  ..., False, False, False],
                         [False, False, False,  ..., False, False, False],
                         ...,
                         [False, False, False,  ..., False, False, False],
                         [False, False, False,  ..., False, False, False],
                         [False, False, False,  ..., False, False, False]],
                
                        [[False, False, False,  ..., False, False, False],
                         [False, False, False,  ..., False, False, False],
                         [False, False, False,  ..., False, False, False],
                         ...,
                         [False, False, False,  ..., False, False, False],
                         [False, False, False,  ..., False, False, False],
                         [False, False, False,  ..., False, False, False]],
                
                        [[False, False, False,  ..., False, False, False],
                         [False, False, False,  ..., False, False, False],
                         [False, False, False,  ..., False, False, False],
                         ...,
                         [False, False, False,  ..., False, False, False],
                         [False, False, False,  ..., False, False, False],
                         [False, False, False,  ..., False, False, False]]])
            scores: tensor([0.9985, 0.9976, 0.9961, 0.9929, 0.9903, 0.9878, 0.4669, 0.3010, 0.1244,
                        0.0653])
            bboxes: tensor([[2502.9565, 2859.0486, 2928.9744, 3426.6133],
                        [1159.5984, 1825.1897, 2293.2417, 2866.6699],
                        [ 875.4203,  715.0433, 1974.3251, 1948.3950],
                        [ 921.3087, 1794.6794, 1405.1306, 2443.5981],
                        [2097.5969,  274.3954, 2651.9568,  867.3937],
                        [2234.8176,  858.9927, 2660.0400, 1254.4799],
                        [2235.8906,  854.2073, 2655.7715, 1255.8162],
                        [ 890.0048,  718.0613, 1964.9799, 1935.0585],
                        [2100.9795,  258.7371, 2658.7690,  885.2468],
                        [ 774.6519,  719.2980, 2040.8054, 2042.9907]])
            labels: tensor([1, 0, 0, 0, 2, 2, 5, 1, 5, 3])
        ) at 0x78d54ef062c0>
    gt_instances: <InstanceData(
        
            META INFORMATION
        
            DATA FIELDS
            masks: BitmapMasks(num_masks=0, height=4032, width=3024)
            bboxes: tensor([], size=(0, 4))
            labels: tensor([], dtype=torch.int64)
        ) at 0x78d42ab6e230>
) at 0x78d42ab6dd20>

検出結果の可視化

from mmdet.registry import VISUALIZERS
import mmcv

# Visualizerを構築する
visualizer = VISUALIZERS.build(model.cfg.visualizer)
visualizer.dataset_meta = model.dataset_meta

img = mmcv.imconvert(mmcv.imread(img_path), 'bgr', 'rgb')

# 予測結果を表示する
visualizer.add_datasample(
    'result',
    img,
    data_sample=result,
    draw_gt=False,
    wait_time=0,
    out_file=None,
    pred_score_thr=0.7
)

visualizer.show()

物体が検出されて判定もされていますが、なぜだか文字が小さくて読めない・・・
result.png

関連記事

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?