概要
- 前提
- コード
前提
- configファイルがある
- 学習済みモデル(チェックポイントファイル)がある
- 動画ファイルがある
- 検出するクラスが1つのみ(多クラスはこのコードを改造すればいけるかも...)
コード
以下のコードを参考にした.
- https://github.com/open-mmlab/mmdetection/blob/main/demo/video_demo.py
- https://github.com/open-mmlab/mmdetection/blob/main/mmdet/visualization/local_visualizer.py
- https://github.com/open-mmlab/mmengine/blob/a5f48f7d99ae250b329272f88b961af1d3ebcf1e/mmengine/visualization/visualizer.py
createMaskVideo.py
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import cv2
import mmcv
from mmcv.transforms import Compose
from mmengine.utils import track_iter_progress
from mmdet.apis import inference_detector, init_detector
from mmdet.registry import VISUALIZERS
import numpy as np
import torch
from mmdet.structures.mask import BitmapMasks, PolygonMasks
from mmengine.visualization.utils import tensor2ndarray
def parse_args():
parser = argparse.ArgumentParser(description='MMDetection video demo')
parser.add_argument('video', help='Video file')
parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
'--score-thr', type=float, default=0.3, help='Bbox score threshold')
parser.add_argument('--out', type=str, help='Output video file')
args = parser.parse_args()
return args
def main():
args = parse_args()
assert args.out, \
('Please specify at least one operation (save/show the '
'video) with the argument "--out"')
# build the model from a config file and a checkpoint file
model = init_detector(args.config, args.checkpoint, device=args.device)
# build test pipeline
model.cfg.test_dataloader.dataset.pipeline[0].type = 'LoadImageFromNDArray'
test_pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)
# 動画の読み込みと書き込みの設定をする
video_reader = mmcv.VideoReader(args.video)
video_writer = None
if args.out:
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
video_writer = cv2.VideoWriter(
args.out, fourcc, video_reader.fps,
(video_reader.width, video_reader.height), False)
try:
for frame in track_iter_progress(video_reader):
# インスタンスセグメンテーションを行う
result = inference_detector(model, frame, test_pipeline=test_pipeline)
# インスタンスセグメンテーションの結果からマスクを取り出す
# これはすべてのクラスに対してマスクを取ってきてしまう
# 1クラスのみであればこれで良いが,他クラスの中から特定のマスク動画を取り出したい場合は,工夫しなければならない
result = result.cpu()
pred_instances = result.pred_instances
pred_instances = pred_instances[pred_instances.scores > args.score_thr]
masks = pred_instances.masks
if isinstance(masks, torch.Tensor):
masks = masks.numpy()
elif isinstance(masks, (PolygonMasks, BitmapMasks)):
masks = masks.to_ndarray()
masks = masks.astype(bool)
masks = tensor2ndarray(masks)
masks = masks.astype('uint8') * 255
res = np.zeros_like(frame)
for mask in masks:
rgb = np.zeros_like(frame)
rgb[...] = (1, 1, 1)
rgb = cv2.bitwise_and(rgb, rgb, mask=mask)
res = res + rgb
_, res = cv2.threshold(res, 0, 255, cv2.THRESH_BINARY)
res = res[:, :, 0]
if args.out:
video_writer.write(np.array(res, dtype='uint8'))
finally:
if video_writer:
video_writer.release()
cv2.destroyAllWindows()
if __name__ == '__main__':
main()
ターミナルで以下のコマンドにて使用できる.
$ python createMaskVideo.py [動画ファイルパス] [モデルのconfigファイルパス] [モデルのチェックポイントファイルパス] --out [出力動画ファイルパス]