4
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

MMSegmentationを使ってみた Training編

Last updated at Posted at 2021-07-17

はじめに

からの続きで、今回はトレーニングの実行を行っていきたいと思います。

チュートリアルは、「Standord Background Dataset」というものを使っていましたが、今回はPascalVOCのデータセットを使って、「deeplabv2plus」を使ってみたいと思います。

次回は、オリジナルのデータセットを使って試してみようかと思います。

MMSegmentationのセットアップについては、前回の記事を参照してください。

Pascal VOCの準備

今回は、Pascal VOC2012のデータセットを使ってみたいと思います。
概要やダウンロードについては、以下を参照してください。
https://lib-arts.hatenablog.com/entry/dataset_ml3 

Pascal VOC2012の場合、学習用のイメージは'JPEGImages'に、
アノテーションのイメージは、'SegmentationClass'に入っています。

import os.path as osp
import numpy as np
from PIL import Image
import mmcv
import matplotlib.pyplot as plt
# convert dataset annotation to semantic segmentation map
data_root = 'xxxxx' # Pascal VOCのフォルダを指定します。
img_dir = 'JPEGImages'
ann_dir = 'SegmentationClass'

コンフィグの編集

今回は、Deeplabv3pulsを使い、PascalVOC2012を使うので以下のコンフィグを選択します。
'configs/deeplabv3plus/deeplabv3plus_r50-d8_512x512_20k_voc12aug.py'

VOC2012というものがなく、VOC2012augとなっていましたが、こちらは見つからなかったので
上記コンフィグを以下のように修正します。

_base_ = [
    '../_base_/models/deeplabv3plus_r50-d8.py',
    #'../_base_/datasets/pascal_voc12_aug.py', '../_base_/default_runtime.py',
    '../_base_/datasets/pascal_voc12.py', '../_base_/default_runtime.py',

    '../_base_/schedules/schedule_20k.py'
]
model = dict(
    decode_head=dict(num_classes=21), auxiliary_head=dict(num_classes=21))

修正後、コンフィグを読み込みます。

from mmcv import Config
cfg = Config.fromfile('configs/deeplabv3plus/deeplabv3plus_r50-d8_512x512_20k_voc12aug.py')

コンフィグの修正を行います。
チュートリアルに記載の項目を並べていますが、今回設定した箇所は、

dataset_typeを「PascalVOCDataset」、
data_root上で設定したPascalVOCをダウンロードしたパスに設定
image_dir、ann_dirも上設定したものに設定

学習済みデータから学習を始める場合は、
以下のMMSegmentationのコンフィグページから各モデルのページに移動した後、学習済みモデルへのリンクが貼られています。
https://github.com/open-mmlab/mmsegmentation/tree/master/configs

from mmseg.apis import set_random_seed

# Since we use ony one GPU, BN is used instead of SyncBN
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.model.backbone.norm_cfg = cfg.norm_cfg
cfg.model.decode_head.norm_cfg = cfg.norm_cfg
cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg

# 分類するクラス数(独自に作成したカスタムデータセットの場合は変更してください。)
# cfg.model.decode_head.num_classes = 8
# cfg.model.auxiliary_head.num_classes = 8

# データセットの種類とパスを設定
cfg.dataset_type = 'PascalVOCDataset'
cfg.data_root = data_root

cfg.data.samples_per_gpu = 4
cfg.data.workers_per_gpu=4


cfg.data.train.type = cfg.dataset_type
cfg.data.train.data_root = cfg.data_root
cfg.data.train.img_dir = img_dir
cfg.data.train.ann_dir = ann_dir
cfg.data.train.pipeline = cfg.train_pipeline
# cfg.data.train.split = 'splits/train.txt'

cfg.data.val.type = cfg.dataset_type
cfg.data.val.data_root = cfg.data_root
cfg.data.val.img_dir = img_dir
cfg.data.val.ann_dir = ann_dir
cfg.data.val.pipeline = cfg.test_pipeline
# cfg.data.val.split = 'splits/val.txt'

cfg.data.test.type = cfg.dataset_type
cfg.data.test.data_root = cfg.data_root
cfg.data.test.img_dir = img_dir
cfg.data.test.ann_dir = ann_dir
cfg.data.test.pipeline = cfg.test_pipeline
# cfg.data.test.split = 'splits/val.txt'

# We can still use the pre-trained Mask RCNN model though we do not need to
# use the mask branch
cfg.load_from = 'https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_512x512_40k_voc12aug/deeplabv3plus_r50-d8_512x512_40k_voc12aug_20200613_161759-e1b43aa9.pth'

# Set up working dir to save files and logs.
cfg.work_dir = './work_dirs/tutorial'

# cfg.runner.max_iters = 200
# cfg.log_config.interval = 10
# cfg.evaluation.interval = 200
# cfg.checkpoint_config.interval = 200

# Set seed to facitate reproducing the result
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)

# Let's have a look at the final config used for training
print(f'Config:\n{cfg.pretty_text}')

モデルの学習

定義したモデルを学習します。

from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.apis import train_segmentor


# Build the dataset
datasets = [build_dataset(cfg.data.train)]

# Build the detector
model = build_segmentor(
    cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))
# Add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES

# Create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
train_segmentor(model, datasets, cfg, distributed=False, validate=True, 
                meta=dict())

以下のような結果となりました。
訓練済みモデルを使ったので、もうちょっと精度が出るはずなんですが、今回はとりあえず動かしてみるということで話を進めます。
修正点があれば、後ほど修正します。

2021-xx-xx 19:42:45,850 - mmseg - INFO -
+-------------+-------+-------+
| Class | IoU | Acc |
+-------------+-------+-------+
| background | 81.32 | 88.28 |
| aeroplane | 46.71 | 56.09 |
| bicycle | 0.07 | 0.07 |
| bird | 21.17 | 32.36 |
| boat | 27.94 | 43.89 |
| bottle | 11.81 | 15.67 |
| bus | 59.47 | 83.44 |
| car | 49.13 | 64.26 |
| cat | 34.55 | 72.64 |
| chair | 1.33 | 1.39 |
| cow | 20.61 | 26.24 |
| diningtable | 24.23 | 39.8 |
| dog | 7.93 | 9.43 |
| horse | 22.01 | 31.72 |
| motorbike | 35.9 | 73.1 |
| person | 49.87 | 70.56 |
| pottedplant | 4.75 | 13.69 |
| sheep | 30.99 | 74.26 |
| sofa | 17.27 | 26.65 |
| train | 45.96 | 61.44 |
| tvmonitor | 18.09 | 65.33 |
+-------------+-------+-------+
2021-xx-xx 19:42:45,851 - mmseg - INFO - Summary:
2021-xx-xx 19:42:45,856 - mmseg - INFO -
+-------+------+-------+
| aAcc | mIoU | mAcc |
+-------+------+-------+
| 78.27 | 29.1 | 45.25 |

結果の確認

適当な訓練画像を使って、予測を行います。

from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot
from mmseg.core.evaluation import get_palette

img = mmcv.imread('Path/JPEGImages/2007_000033.jpg') # 画像のパスを記載

model.cfg = cfg
result = inference_segmentor(model, img)
plt.figure(figsize=(8, 6))
show_result_pyplot(model, img, result, get_palette('pascal_voc'))

image.png

対象の中では、精度がいい方である飛行機の画像を使ってみました。
他のものは、あまり認識できていなかったようです。。

4
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?