1. 概要
これまで、Semantic Segmentation modelsを用いて、航空機や衛星画像の建物のセグメンテーションや、車載画像を例に多数クラスのセマンティックセグメンテーションを紹介しました.
衛星画像のSegmentation(セグメンテーション)により建物地図を作成する.
PyTorchによるMulticlass Segmentation - 車載カメラ画像のマルチクラスセグメンテーションについて.
本記事では、Deep Learningの最新技術を用いた物体検知やセマンティックセグメーテーションを比較的簡単に実行・実装できる、OpenMMLabのMMSegmentationの使い方について紹介します。例えば、多数クラスの車載画像によるセマンティック セグメンテーションのモデルを構築することで、以下の予測結果を得ることができました.
(車載画像: Motion-based Segmentation and Recognition Dataset )
ここでは、MMSegmentationのインストール方法、デモ画像による実行、自前データ(今回はCamVidの車載画像)による学習モデルの実行と保存、学習済みモデルの読み込みと新規画像のセマンティックセグメンテーションの実行について紹介します。
MMSegmentationが初めての方を対象としたため,かなり細かく紹介しています.そのため,長文となりましたので,慣れている方はポイントだけ見てください.
ここで用いたコードはGithubにアップしましたので,ご興味のある方は試してみてください.Jupyter lab(notebook)で実行できます.ご参考になれば幸いです.
環境
本記事の実装環境は以下となります.
OS:Ubuntu: 18.04LTS
GPU:GeoForce RTX3080
Python: 3.7
Pytorch: 1.8.2
Cuda: 11.1
MMSegmentation: 0.21.0
2. セグメンテーションモデル
セグメンテーション(正確には,Semantic Segmentation)に関する記事は多数あります.例えば,以下の記事ではセグメンテーションのモデルでが紹介されています.
前回の記事でSemntaic Segmentaion modelsでの多数クラスのマルチセグメンテーションの方法を紹介しましたが,今回は多数クラスのセグメンテーションを比較的容易に試すことができる以下のモジュールを使ってみました.特に新たにモデルをつくるのは、ここで紹介しますが、Configurationを変更することで簡単に試すことができます。
このモジュールはPytorchをベースとした複数のSegmentationのモデルが準備されており,そのモデルには,Unet, FPN, PSPNet, PAN, DeepLabV3やその他最新のモデルがあります.
また,それぞれのモデルの学習済みのモデルも用意されており,転移学習やFine tuningができるため,比較的少ない学習データであっても高い精度のモデルが構築できることが期待されます.
ここでは,MMSegmentationのデモにも用いられているPSPNETによる車載画像の道路、車、人、建物などの多数クラスののセグメンテーションのモデルの構築例を紹介します.
3. MMSegmentationのインストール
MMSegmentationのモデルは、サイトのチュートリアルに例が紹介されていますが、このままではエラーとなるため、ご自身の環境にあわせてインストールするバージョンを選択する必要があります。
まず、以下を実行し、cudaとPytorchのバージョンを確認します。
# Check nvcc version
!nvcc -V
# Check GCC version
!gcc --version
# Check Pytorch installation
import torch, torchvision
print(torch.__version__, torch.cuda.is_available())
例えば、私の環境の出力結果が以下となります。
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Oct_12_20:09:46_PDT_2020
Cuda compilation tools, release 11.1, V11.1.105
Build cuda_11.1.TC455_06.29190527_0
gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Copyright (C) 2017 Free Software Foundation, Inc.
This is free software; see the source for copying conditions. There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
1.8.2 True
これより、cudaとPytorchのバージョンが、それぞれ11.1と1.8.2であることがわかりましたので、それにあわせたMMCVをはじめにインストールします。これは、MMSegmentationのバックエンドで動く基本モジュールであり、画像表示などにも使用します。
#mmcv install
!pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.8.2/index.html
その後、以下を実行しMMSegmentationをインストールします。
!rm -rf mmsegmentation
!git clone https://github.com/open-mmlab/mmsegmentation.git
%cd mmsegmentation
!pip install -e .
これで、MMSegmentationを実行する環境を構築しました。
次に、学習用画像の取得、学習とモデルの構築、学習済みモデルによる新規画像のセグメンテーションの実行と評価を順に行います。
4. 車載画像の取得方法
ここで用いる画像データは,前回と同様に以下のサイトより取得します.
Motion-based Segmentation and Recognition Dataset
ここには,車載カメラから撮像された画像と、車、歩行者、道路を含む 32 個のラベル画像が提供されています。このサイトから画像を取得し、それをベースに実行してもよいのですが、今回は以前にも紹介しましたが、Segmentaion-modelsで紹介されている方法より画像を取得します。
こちらのサイトからCamVidの車載カメラの画像とラベル画像を取得します。
import os
DATA_DIR = './data/CamVid/'
# load repo with data if it is not exists
if not os.path.exists(DATA_DIR):
print('Loading data...')
os.system('git clone https://github.com/alexgkendall/SegNet-Tutorial ./data')
print('Done!')
インストールしたMMSegmentationのモデル、および車載画像は以下のディレクトリ構成となります。mmsegmentationには他にもディレクトリがありますが、ここではこの後用いるもののみを示しています。
mmsegmentation
|
|-SegNet-Tutorial:取得した車載画
| |
| |-train
| |
| |-valannot
| |
| |-train
| |
| |-valannot
|
|-config:各種モデル
|
|-checkpoints:学習済みモデル(ダウンロード)
|
|-work_dirs:新規学習情報の保存(*この後作成します。)
5.セグメンテーションの実装による学習と評価
ここでは、MMSegmentationによる新規学習方法とその学習結果の評価・保存について紹介します。
5.1 PSPNETの学習済みモデルによるデモ実行
はじめに、MMSegmentationのチュートリアルにある学習済みモデルのダウンロードと、サンプルを実行します。
#学習済みモデルの入手
%cd mmsegmentation
!mkdir checkpoints
!wget https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth -P checkpoints
こちらを実行すると、MMSegmentationの学習済みのPSPモデルをダウンロードします。次に、以下を実行し必要なモジュールをインポートします。その他の学習済みモデルは、MMSegmentationのサイトに多数紹介されていますで、色々試してみてください。
#mmsegmentationのモジュールの呼び出し
from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot
from mmseg.core.evaluation import get_palette
そして、学習済みモデルを読み出し、モデルを設定します。
#モデルの選択、および学習済みモデルの呼び出し
config_file = 'configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py'
checkpoint_file = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'
#mmsegmentationの実行(GPUあり)
model = init_segmentor(config_file, checkpoint_file, device='cuda:0')
これでセグメンテーションの実行準備ができました。簡単です。では、デモ画像を実行します。
# test a single image
img = 'demo/demo.png'
result = inference_segmentor(model, img)
# show the results
show_result_pyplot(model, img, result, get_palette('cityscapes'))
デモ画像にセグメンテーションの処理結果のラベル画像が重畳されています。うまく、道路、人、自転車、車どが分離されているのがわかります。画像サイズは、学習モデルと同じ512x1024ですね。
では、次に今回取得した車載画像のセグメンテーションを、こちらのモデルで実行してみます。
まずは、使用するモジュールのインポートです。
#モジュールのimport
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import numpy as np
import cv2
import matplotlib.pyplot as plt
では、対象となる画像を選択し表示します。
# Let's take a look at the dataset
import mmcv
img = mmcv.imread('SegNet-Tutorial/CamVid/train/0001TP_006690.png')
print(img.shape)
plt.figure(figsize=(8, 6))
plt.imshow(mmcv.bgr2rgb(img))
plt.show()
CamVidの画像を表示しました。比較的暗めの画像ですね。では、学習済みモデルをもちいたセグメンテーションを実行します。
# test a single image
img = 'SegNet-Tutorial/CamVid/train/0001TP_006690.png'
result = inference_segmentor(model, img)
show_result_pyplot(model, img, result, get_palette('cityscapes'))
学習済みモデルの画像に近い画像なので、それなりにはクラス分けされたセグメンテーション画像にはなっていますが、こののままでは使えないですね。例えば、この車載画像のラベル画像はこちらとなります。
# Let's take a look at the segmentation map we got
import matplotlib.patches as mpatches
img = Image.open('SegNet-Tutorial/CamVid/trainannot/0001TP_006690.png')
plt.figure(figsize=(8, 6))
img_p = np.array(img.convert('P'))
plt.imshow(np.array(img.convert('P'))) #'RGB'
plt.show()
セグメンテーションによる予測結果と比べて、精度がそれほど高くなく、画像をうまくクラス分け(セグメンテーション)できていないのがわかります。では、車載画像による学習モデルを構築します。
5.2 新規学習モデルの構築
はじめに、今回用いる学習画像を準備します。
import os.path as osp
import numpy as np
from PIL import Image
# convert dataset annotation to semantic segmentation map
data_root = 'SegNet-Tutorial/CamVid'
img_dir = 'train'
ann_dir = 'trainannot'
# define class and plaette for better visualization
classes = ('sky', 'Bulding', 'Pole', 'Road_marking', 'Road', 'Pavement', 'Tree', 'SingSymbole','Fence', 'Car', 'Pedestrian', 'Bicyclist', 'Unlabeled')
palette = [[128,128,128], [128,0,0], [192,192,128], [255,69,0], [128,64,128], [60,40,222], [128,128,0], [192,128,128], [64,64,128], [64,0,128], [64,64,0], [0,128,192], [0,0,0]]
車載画像は8クラスあるため、上記のように分類します。また、セグメンテーションの実行結果と画像とを重畳するときにカラー(palette)もここで設定します。
では、学習と評価に用いる画像を設定します。
# split train/val set randomly
split_dir = 'splits_resnet50A'
mmcv.mkdir_or_exist(osp.join(data_root, split_dir))
filename_list = [osp.splitext(filename)[0] for filename in mmcv.scandir(
osp.join(data_root, ann_dir), suffix='.png')]
with open(osp.join(data_root, split_dir, 'train.txt'), 'w') as f:
# select first 4/5 as train set
train_length = int(len(filename_list)*4/5)
f.writelines(line + '\n' for line in filename_list[:train_length])
with open(osp.join(data_root, split_dir, 'val.txt'), 'w') as f:
# select last 1/5 as train set
f.writelines(line + '\n' for line in filename_list[train_length:])
今回はTrainの画像のみを用いました。全体のうち、4/5を学習用(train)に、残りの1/5を評価用(val)に用います。ここでは、学習および評価用のファイル名をリストにして、'splits_resnet50A'ディレクトリに保存します。もし、Val画像を評価用に用いたい場合は、splitディレクトリにそれぞれのファイルリストを設定してください。
次に、Pytorchによる画像解析の実行と同じく、MMSegmentationのフォーマットにあわせたDatasetを作成します。
from mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset
@DATASETS.register_module()
class splits_resnet50A(CustomDataset):
CLASSES = classes
PALETTE = palette
def __init__(self, split, **kwargs):
super().__init__(img_suffix='.png', seg_map_suffix='.png',
split=split, **kwargs)
assert osp.exists(self.img_dir) and self.split is not None
画像とアノテーション画像のフォーマット(拡張子)や、Datasetの名前(ここでは、splits_resnet50A)を設定します。
次に、学習モデルを設定します。
from mmcv import Config
cfg = Config.fromfile('configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py')
いちから自分で構築することもできますが、学習済みのモデルを用いるほうが間違いがなく近道と思います。また、最新のモデルがサイトに日々更新されていますので、どのような構成なのか、こちらで呼び出し試すこともできます。また、モデルによって精度がどれだか変わるのか試すこともできますので、設定されているモデルを用いることをオススメします。ここでは、対象画像がcityscapesのデータ・セットと同じ構成ですので、こちらを対象としました。どういうモデルか気になる方は、以下を実行すると構成を確認できます。
print(f'Config:\n{cfg.pretty_text}')
次に、呼び出したモデルを、今回の画像の仕様にあわせて個々に設定します。
from mmseg.apis import set_random_seed
# add CLASSES and PALETTE to checkpoint 1)
cfg.checkpoint_config.meta = dict(
CLASSES=classes,
PALETTE=palette)
# Since we use ony one GPU, BN is used instead of SyncBN
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.model.backbone.norm_cfg = cfg.norm_cfg
cfg.model.decode_head.norm_cfg = cfg.norm_cfg
cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg
# modify num classes of the model in decode/auxiliary head
cfg.model.decode_head.num_classes = len(classes) # 2)
cfg.model.auxiliary_head.num_classes = len(classes) # 2)
# Modify dataset type and path
cfg.dataset_type = 'splits_resnet50A' #3)
cfg.data_root = data_root
cfg.data.samples_per_gpu = 8
cfg.data.workers_per_gpu=8
cfg.img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
cfg.crop_size = (256, 256)
cfg.train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(w, h), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=cfg.crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **cfg.img_norm_cfg),
dict(type='Pad', size=cfg.crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
cfg.test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(w, h), #6)
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **cfg.img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
cfg.data.train.type = cfg.dataset_type
cfg.data.train.data_root = cfg.data_root
cfg.data.train.img_dir = img_dir
cfg.data.train.ann_dir = ann_dir
cfg.data.train.pipeline = cfg.train_pipeline
cfg.data.train.split = 'splits_resnet50A/train.txt' #3)
cfg.data.val.type = cfg.dataset_type
cfg.data.val.data_root = cfg.data_root
cfg.data.val.img_dir = img_dir
cfg.data.val.ann_dir = ann_dir
cfg.data.val.pipeline = cfg.test_pipeline
cfg.data.val.split = 'splits_resnet50A/val.txt' #3)
cfg.data.test.type = cfg.dataset_type
cfg.data.test.data_root = cfg.data_root
cfg.data.test.img_dir = img_dir
cfg.data.test.ann_dir = ann_dir
cfg.data.test.pipeline = cfg.test_pipeline
cfg.data.test.split = 'splits_resnet50A/val.txt' #3)
# We can still use the pre-trained Mask RCNN model though we do not need to
# use the mask branch
#cfg.load_from = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth' #5)
# Set up working dir to save files and logs.
cfg.work_dir = './work_dirs/tutorial_pspnet_r50A' #4)
cfg.runner.max_iters = 40000
cfg.log_config.interval = 10
cfg.evaluation.interval = 200
cfg.checkpoint_config.interval = 1000
# Set seed to facitate reproducing the result
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)
# Let's have a look at the final config used for training
print(f'Config:\n{cfg.pretty_text}')
チュートリアルからいくつか変更しています。
1)学習モデルの保存・呼び出し時にClassとPaletteを呼び出すため、こちらの条件を追加で設定。
2)使用する画像のクラス数に合わせて変更
3)画像、アノテーション画像のディレクトリに変更
4)学習過程のログやモデルの保存先を変更
5)今回は、学習済みモデルを用いないため、コメントアウト
6) 画像サイズに合わせて変更
学習モデルやDatasetについては、チュートリアルサイトで詳しく紹介されていますので、こちらもご参考にください。
こちらを実行すると、構築した学習モデルの構成が表示されます。こちらが学習条件になりますので、ご確認ください。
では、学習を実行します。
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.apis import train_segmentor
# Build the dataset
datasets = [build_dataset(cfg.data.train)]
# Build the detector
model = build_segmentor(
cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))
# Add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES
# Create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
train_segmentor(model, datasets, cfg, distributed=False, validate=True,
meta=dict())
Configurationでevaluation.intervalで設定したエポック数毎に評価が実行されます。例えば、最初200エポック時の評価結果が以下となります。
+--------------+-------+-------+
| Class | IoU | Acc |
+--------------+-------+-------+
| sky | 80.57 | 88.45 |
| Bulding | 51.52 | 88.17 |
| Pole | 0.0 | 0.0 |
| Road_marking | 87.36 | 97.79 |
| Road | 9.7 | 9.93 |
| Pavement | 6.5 | 6.8 |
| Tree | 0.0 | 0.0 |
| SingSymbole | 0.0 | 0.0 |
| Fence | 43.88 | 65.66 |
| Car | 0.0 | 0.0 |
| Pedestrian | 0.0 | 0.0 |
| Bicyclist | 7.28 | 9.24 |
| Unlabeled | nan | nan |
+--------------+-------+-------+
+-------+------+------+
| aAcc | mIoU | mAcc |
+-------+------+------+
| 73.18 | 23.9 | 30.5 |
各クラスのIoUと全体のmIoUを確認できます。これからですね。では、最終40000エポック後の評価結果をみてみます。
+--------------+-------+-------+
| Class | IoU | Acc |
+--------------+-------+-------+
| sky | 91.13 | 96.07 |
| Bulding | 90.62 | 96.37 |
| Pole | 24.29 | 30.84 |
| Road_marking | 98.12 | 99.03 |
| Road | 85.33 | 90.89 |
| Pavement | 81.3 | 91.49 |
| Tree | 67.06 | 77.0 |
| SingSymbole | 73.13 | 82.6 |
| Fence | 91.65 | 96.01 |
| Car | 53.04 | 64.28 |
| Pedestrian | 64.97 | 87.69 |
| Bicyclist | 57.39 | 65.76 |
| Unlabeled | nan | nan |
+--------------+-------+-------+
+-------+-------+------+
| aAcc | mIoU | mAcc |
+-------+-------+------+
| 94.37 | 73.17 | 81.5 |
+-------+-------+------+
mIoUが73%とかなり高い精度であることがわかります。では、学習によってモデルの性能がどのように向上したのか、グラフにしてみてみます。まずは、jsonフォーマットのlogファイルを読み込みます。
import json
log_file = './work_dirs/tutorial_pspnet_r50A/None.log.json'
res = []
decoder = json.JSONDecoder()
with open(log_file, 'r') as f:
line = f.readline()
while line:
res.append(decoder.raw_decode(line))
line = f.readline()
読み込んだファイルから、学習数に対する各性能をリストにします。
x_epoch_data = []
aux_loss = []
aux_acc = []
loss = []
x_epoch_data_2 = []
mIou = []
mAcc = []
for i in range(len(res)-1):
if 'aux.loss_ce' in res[i+1][0]:
x_epoch_data.append(res[i +1][0]['iter'])
#print(res[i +1][0]['aux.loss_ce'])
aux_loss.append(res[i +1][0]['aux.loss_ce'])
aux_acc.append(res[i +1][0]['aux.acc_seg'])
loss.append(res[i +1][0]['loss'])
elif 'mIoU' in res[i+1][0]:
#print(i)
x_epoch_data_2.append(res[i][0]['iter']) # iter number before evaluation
mIou.append(res[i +1][0]['mIoU'])
mAcc.append(res[i +1][0]['mAcc'])
else:
pass
では、グラフ化します。
fig = plt.figure(figsize=(14, 5))
ax1 = fig.add_subplot(1, 2, 1)
ax2 = ax1.twinx()
line1, = ax2.plot(x_epoch_data,aux_loss,label='aux_loss',color='red')
line2, = ax1.plot(x_epoch_data,aux_acc,label='aux_acc',color='blue')
line3, = ax2.plot(x_epoch_data,loss,label='loss',color='green')
ax1.set_title("loss/acc loss")
ax1.set_xlabel('epoch')
ax1.set_ylabel('aux-acc')
ax2.set_ylabel('loss')
ax1.set_ylim(0, 100)
ax1.legend(loc='upper right')
ax2.legend(loc='lower right')
ax3 = fig.add_subplot(1, 2, 2)
line1, = ax3.plot(x_epoch_data_2,mIou,label='mIoU',color='blue')
line2, = ax3.plot(x_epoch_data_2,mAcc,label='mAcc',color='green')
ax3.set_title("mIoU/mAcc score")
ax3.set_xlabel('epoch')
ax3.set_ylabel('mIoU/mAcc_score')
ax3.set_ylim(0, 1)
ax3.legend(loc='upper left')
plt.show()
学習することでモデルの性能が向上しているのがわかります。今回は、40000エポックで終了しましたが、この傾向からこれ以上学習回数を増やしても性能はほとんど変わらないことがわかりました。
最後に今回の学習の条件(Configration)を保存します。これは、MMSegmentationのサイトでの説明がありません。はじめは、学習時に読み出したモデルを使えばよいと考えたのですが、呼び出してもうまくモデルを回すことができませんでした。読み出したモデルと今回の条件とでは異なりますので、当たり前ですね。ご注意ください。
#save config
cfg_file = cfg.pretty_text
cfg_path = './work_dirs/tutorial_pspnet_r50A/cfg.py'
with open(cfg_path, mode='w') as f:
f.write(cfg_file)
これで、学習モデルの構築は終了です。
##5.3 学習済みモデルの呼び出し、および新規画像によるセグメンテーションの実行
はじめに、学習済みモデルを以下を実行し読み出します。
#mmsegmentationの実行のロード(GPUあり)
config_file = cfg_path #configファイルのパス
checkpoint_file = 'work_dirs/tutorial_pspnet_r50A/latest.pth' #学習済みの最終モデル
model = init_segmentor(config_file, checkpoint_file, device='cuda:0')
これでモデルの準備ができました。 では、学習済みモデルを用いた、車載画像のセグメンテーションを実行します。
img = mmcv.imread('SegNet-Tutorial/CamVid/train/0001TP_006690.png')
model.cfg = cfg
result = inference_segmentor(model, img)
plt.figure(figsize=(8, 6))
show_result_pyplot(model, img, result, palette)
最初のPSPNETの学習モデルによる予測結果と比べて、かなりうまくセグメンテーション(クラス分類)できているのがわかります。では、他の画像はどうなのか、確認してみます。まず、それぞれの画像を並べて表示する関数を準備します。
# 車載画像、アノテーション画像、予測画像を並べて表示する
def visualize(**images):
"""PLot images in one row."""
n = len(images)
plt.figure(figsize=(16, 5))
for i, (name, image) in enumerate(images.items()):
plt.subplot(1, n, i + 1)
plt.xticks([])
plt.yticks([])
plt.title(' '.join(name.split('_')).title())
plt.imshow(image)
plt.show()
では、実行します。今回は、Train画像からランダムに4枚選びました。実行のたびに新しいセグメンテーション結果が表示されますので、試してみてください。
file_list = os.listdir('SegNet-Tutorial/CamVid/train')
print('the number of the image: ', len(file_list))
for i in range(4):
n = np.random.choice(len(file_list))
image2 = mmcv.imread(os.path.join('SegNet-Tutorial/CamVid/train', file_list[n]))
result = inference_segmentor(model, image2)
pred_mask2 = np.array(result).transpose(1, 2, 0)
mask2 = Image.open(os.path.join('SegNet-Tutorial/CamVid/trainannot', file_list[n]))
visualize(
image=image2,
ground_truth=mask2,
predict_mask = result[0],
)
学習に用いた画像が300枚程度とかなり少ないのですが、うまくセグメンテーションができていますね。
6. まとめ
Pytorchを用いたMMSegmentationによる多数クラス画像のセマンティックセグメンテーションの方法を紹介しました。
MMSegmentationは、PytorchのDatasetと異なる方法でDatasetを作成するため、はじめは戸惑いましたが、慣れると簡単に作成できます。また、有名なデータ・セットの各モデルの学習済みモデルも容易されていますので、それをダウンロードした転移学習で高精度なモデルを構築することができます。また、セグメンテーションの特徴として、小さな構造物(今回でいえば、人やポールなど)は低いIoUとなることがおおいため、各クラスのWeightを個々に設定することで、その配分を変えて小さなものも検知しやくすることもできます。これらの方法については、次回紹介します。
私は衛星画像のセグメンテーションに関心をもっており、その画像から建物や道路などの構造物や、農地・森林などの土地情報などの多数クラスのセグメンテーションがこちらをもいることで可能です。これにより、衛星画像から最新の地図を作成することができます。例えば,無料で提供されている欧州宇宙機関(ESA)のSentinel-2は同一地点を5日毎に観測しており,被雲率を考慮し月に1回利用したとしても,関心地域のトレンド(例えば街の発展具合など)を把握することができます.また,災害前後の画像を比較すると,建物などの構造物の被害がフォーカスされることで,その被害規模を把握することに利用されることが期待されます.
次回は,衛星画像を対象に、セグメンテーションの実験を行い、精度向上の工夫とその評価を紹介します。
長文記事を最後までご覧いただきありがとうございました.私はこの分野の専門ではないため,間違って解釈しているところがあるかと思います.ご指摘いただければ幸いです.
また,コメント等ありましたらいただけると嬉しいです.
航空機や人工衛星の画像に関心を持つ方が多くなり,その応用の発展に寄与できれば幸いです.
7. 参考記事
衛星画像のSegmentation(セグメンテーション)により建物地図を作成する.
PyTorchによるMulticlass Segmentation - 車載カメラ画像のマルチクラスセグメンテーションについて..
セマンティックセグメンテーションをざっくり学ぶ
MMSegmentation
Motion-based Segmentation and Recognition Dataset
人工衛星(Sentinel-2)の観測画像をAPIを使って自動取得してみた.