5
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

MMSegmentationを使ってみた オリジナルデータトレーニング編

Posted at

はじめに

からの続きで、オリジナルのデータを使った場合のMMSegmentationの利用方法について、記載したいと思います。

対象のデータ

用意するデータは、画像及びアノテーション画像でPascal VOCに準じた形式を利用するものとします。

アノテーションデータについては、RGBの画像データではなく、0,1,2,...のカラーパレットのインデックスのデータを用意したものとします。

データの作成方法については、以下の記事などを参考にしてください。

データセットの設定

データセットのパスなどを指定します。

import os.path as osp
import numpy as np
from PIL import Image
import mmcv
import matplotlib.pyplot as plt
# convert dataset annotation to semantic segmentation map
data_root = データを格納しているルートフォルダのパス
img_dir = 学習用の画像のフォルダ名
ann_dir = アノテーション画像のフォルダ名

画像クラスの設定と色の指定
今回は、背景と建物の2値分類として、建物を赤で表示するようにします。

classes = ('background','builing')
palette = [[0,0,0],[255, 0, 0]]

訓練データと検証データの分割

# split train/val set randomly
split_dir = 'splits'
mmcv.mkdir_or_exist(osp.join(data_root, split_dir))
filename_list = [osp.splitext(filename)[0] for filename in mmcv.scandir(
    osp.join(data_root, ann_dir), suffix='.png')]
with open(osp.join(data_root, split_dir, 'train.txt'), 'w') as f:
  # select first 4/5 as train set
  train_length = int(len(filename_list)*4/5)
  f.writelines(line + '\n' for line in filename_list[:train_length])
with open(osp.join(data_root, split_dir, 'val.txt'), 'w') as f:
  # select last 1/5 as train set
  f.writelines(line + '\n' for line in filename_list[train_length:])

オリジナルデータセットの作成

from mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset

@DATASETS.register_module()
class OriginalDataset(CustomDataset):
    CLASSES = classes
    PALETTE = palette
    def __init__(self, split, **kwargs):
        super().__init__(img_suffix='.png', seg_map_suffix='.png', 
                     split=split, **kwargs)
        assert osp.exists(self.img_dir) and self.split is not None

コンフィグファイルの編集

前回と同様でdeeplabv3plusのコンフィグを利用することにします。

コンフィグファイルの読み込み

from mmcv import Config
cfg = Config.fromfile('configs/deeplabv3plus/deeplabv3plus_r50-d8_512x512_20k_voc12aug.py')

コンフィグの編集
上記で設定した分類するクラスの数、対象のデータセット名、
訓練データのパスなどを編集します。

from mmseg.apis import set_random_seed

# Since we use ony one GPU, BN is used instead of SyncBN
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.model.backbone.norm_cfg = cfg.norm_cfg
cfg.model.decode_head.norm_cfg = cfg.norm_cfg
cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg

# 分類するクラス数
cfg.model.decode_head.num_classes = 2
cfg.model.auxiliary_head.num_classes = 2

# Modify dataset type and path
cfg.dataset_type = 'OriginalDataset'
cfg.data_root = data_root

cfg.data.samples_per_gpu = 4
cfg.data.workers_per_gpu=4


cfg.data.train.type = cfg.dataset_type
cfg.data.train.data_root = cfg.data_root
cfg.data.train.img_dir = img_dir
cfg.data.train.ann_dir = ann_dir
cfg.data.train.pipeline = cfg.train_pipeline
cfg.data.train.split = 'splits/train.txt'

cfg.data.val.type = cfg.dataset_type
cfg.data.val.data_root = cfg.data_root
cfg.data.val.img_dir = img_dir
cfg.data.val.ann_dir = ann_dir
cfg.data.val.pipeline = cfg.test_pipeline
cfg.data.val.split = 'splits/val.txt'

cfg.data.test.type = cfg.dataset_type
cfg.data.test.data_root = cfg.data_root
cfg.data.test.img_dir = img_dir
cfg.data.test.ann_dir = ann_dir
cfg.data.test.pipeline = cfg.test_pipeline
cfg.data.test.split = 'splits/val.txt'

# We can still use the pre-trained Mask RCNN model though we do not need to
# use the mask branch
# cfg.load_from = 'https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_512x512_40k_voc12aug/deeplabv3plus_r50-d8_512x512_40k_voc12aug_20200613_161759-e1b43aa9.pth'

# Set up working dir to save files and logs.
cfg.work_dir = './work_dirs/tutorial'

# cfg.runner.max_iters = 200
# cfg.log_config.interval = 10
# cfg.evaluation.interval = 200
# cfg.checkpoint_config.interval = 200

# Set seed to facitate reproducing the result
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)

# Let's have a look at the final config used for training
print(f'Config:\n{cfg.pretty_text}')

訓練の実施

前回と同様です。
公開可能なオリジナルデータを所持していないので、コードのみとなります。

from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.apis import train_segmentor


# Build the dataset
datasets = [build_dataset(cfg.data.train)]

# Build the detector
model = build_segmentor(
    cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))
# Add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES

# Create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
train_segmentor(model, datasets, cfg, distributed=False, validate=True, 
                meta=dict())

結果の表示

こちらも前回と同様となります。

from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot
from mmseg.core.evaluation import get_palette

img = mmcv.imread(表示したいファイルのパス)

model.cfg = cfg
result = inference_segmentor(model, img)
plt.figure(figsize=(8, 6))
show_result_pyplot(model, img, result, palette)

これで一通り、MMSegmentaionの利用の仕方が理解できました。
MMdetectionと合わせて、画像検知系のタスクは利用できるようになったかと思いますし、
最新モデルもすぐに反映されるように思いますので、比較的すぐ最新モデルを試すことができると思います。

5
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?