GluonCV(Gluon/MXNetのコンピュータビジョン用ツールキット)が便利だったので,ノートPCのWebカメラの映像をGPUなし(CPUのみ)で処理してみた.
環境設定
まずはAnaconda(またはMiniconda)をインストールし,専用の仮想環境を作ってから必要なパッケージを入れる.
# Create and activate a dedicated Python 3.8 environment.
conda create --name gluon python=3.8
conda activate gluon
# opencv-python provides the cv2 module that the scripts below load via try_import_cv2().
python -m pip install mxnet torch torchvision gluoncv opencv-python
物体検出
モデルはSSD(ssd_512_mobilenet1.0_coco).精度はまあまあだが,速度はCPUでも3fps程度出る.
import mxnet as mx
import time
import gluoncv as gcv
from gluoncv.utils import try_import_cv2
cv2 = try_import_cv2()

# Real-time object detection on webcam frames with a pretrained GluonCV model.
# Press 'q' in the display window to quit.

# Pretrained detector from the GluonCV model zoo; swap the name to try others.
net = gcv.model_zoo.get_model(
    # good, fast
    'ssd_512_mobilenet1.0_coco',
    # 'ssd_512_mobilenet1.0_voc',
    # 'ssd_512_mobilenet1.0_voc_int8',
    #
    # 'yolo3_mobilenet1.0_coco',
    # 'yolo3_mobilenet1.0_voc',
    # too slow...
    # 'faster_rcnn_resnet50_v1b_voc',  # too slow...
    # 'faster_rcnn_fpn_syncbn_resnest50_coco',  # too slow...
    pretrained=True)
net.hybridize()  # let MXNet build/cache a static graph for the network

cap = cv2.VideoCapture(0)  # device 0 = default camera
time.sleep(1)  # presumably gives the camera time to start before the first read
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            # No frame (camera missing, busy or disconnected): stop instead of
            # passing None to cv2.cvtColor.
            break
        # OpenCV delivers BGR; convert to an RGB uint8 NDArray for GluonCV.
        frame = mx.nd.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).astype('uint8')
        # rgb_nd is the network input; frame is rebound to the resized image for drawing.
        rgb_nd, frame = gcv.data.transforms.presets.ssd.transform_test(
            frame, short=512, max_size=700
        )
        # Use the matching preset when switching to a YOLO / Faster R-CNN model:
        # rgb_nd, frame = gcv.data.transforms.presets.yolo.transform_test(
        #     frame, short=512, max_size=700
        # )
        # rgb_nd, frame = gcv.data.transforms.presets.rcnn.transform_test(
        #     frame, short=512, max_size=700
        # )
        class_IDs, scores, bounding_boxes = net(rgb_nd)
        img = gcv.utils.viz.cv_plot_bbox(frame,
                                         bounding_boxes[0],
                                         scores[0],
                                         class_IDs[0],
                                         class_names=net.classes)
        gcv.utils.viz.cv_plot_image(img)
        # Poll the keyboard exactly once per iteration: this call both refreshes
        # the window and reads the key. A second, unchecked waitKey would consume
        # (and discard) a 'q' pressed while the frame was being processed.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Release the camera and close windows even on an exception or Ctrl+C.
    cap.release()
    cv2.destroyAllWindows()
姿勢推定
SSDで人物を検出したあと,検出した人物ごとに切り出して姿勢推定を行う(トップダウン方式).
Simple PoseはバックボーンがResNet18でもResNet152でも4-5fpsで動く.
mobile poseとalpha poseは動作が不安定?
from gluoncv import model_zoo
from gluoncv.data.transforms.pose import (
    detector_to_simple_pose,
    detector_to_alpha_pose,  # only used by the commented-out Alpha Pose variant
    heatmap_to_coord
)
import mxnet as mx
import time
import gluoncv as gcv
from gluoncv.utils import try_import_cv2
cv2 = try_import_cv2()

# Top-down pose estimation on webcam frames: detect people with SSD, then run
# a keypoint network on each detected person. Press 'q' to quit.

# Person detector (restricted to the "person" class below).
detector = model_zoo.get_model(
    'ssd_512_mobilenet1.0_voc',
    pretrained=True)
# Keypoint network applied to each person crop; swap the name to try others.
pose_net = model_zoo.get_model(
    'simple_pose_resnet18_v1b',
    # 'simple_pose_resnet50_v1d',
    # 'simple_pose_resnet101_v1b',
    # 'simple_pose_resnet152_v1b',
    # 'mobile_pose_mobilenetv2_1.0',
    # 'mobile_pose_mobilenetv3_small',
    # 'mobile_pose_mobilenetv3_large',
    # 'alpha_pose_resnet101_v1b_coco',
    pretrained=True)
# Keep only the "person" class, reusing its pretrained weights.
detector.reset_class(["person"],
                     reuse_weights=['person'])
detector.hybridize()
pose_net.hybridize()

cap = cv2.VideoCapture(0)  # device 0 = default camera
time.sleep(1)  # presumably gives the camera time to start before the first read
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            # No frame from the camera: stop instead of passing None to cvtColor.
            break
        k = cv2.waitKey(1)  # also refreshes the window drawn in the previous iteration
        if k & 0xFF == ord('q'):
            break
        # OpenCV delivers BGR; convert to an RGB uint8 NDArray for GluonCV.
        frame = mx.nd.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).astype('uint8')
        # x is the detector input; img is the resized RGB image used for cropping/drawing.
        x, img = gcv.data.transforms.presets.ssd.transform_test(
            frame, short=512, max_size=700
        )
        class_IDs, scores, bounding_boxes = detector(x)
        # NOTE(review): GluonCV's Alpha Pose demo pairs detector_to_alpha_pose with
        # heatmap_to_coord_alpha_pose, not heatmap_to_coord - confirm before switching.
        # pose_input, upscale_bbox = \
        #     detector_to_alpha_pose(img, class_IDs, scores, bounding_boxes)
        pose_input, upscale_bbox = \
            detector_to_simple_pose(img, class_IDs, scores, bounding_boxes)
        pose_img = img
        # detector_to_simple_pose yields pose_input=None when no person is detected;
        # skip the pose network then (as GluonCV's webcam demo does) instead of crashing.
        if pose_input is not None:
            predicted_heatmap = pose_net(pose_input)
            pred_coords, confidence = heatmap_to_coord(predicted_heatmap, upscale_bbox)
            pose_img = gcv.utils.viz.cv_plot_keypoints(img,
                                                       pred_coords,
                                                       confidence,
                                                       class_IDs,
                                                       bounding_boxes,
                                                       scores,
                                                       box_thresh=0.5,
                                                       keypoint_thresh=0.2)
        # The image is RGB (converted above) but cv2.imshow expects BGR.
        cv2.imshow('pose_img', cv2.cvtColor(pose_img, cv2.COLOR_RGB2BGR))
finally:
    # Release the camera and close windows even on an exception or Ctrl+C.
    cap.release()
    cv2.destroyAllWindows()
セグメンテーション(インスタンスセグメンテーション)
Mask R-CNN(mask_rcnn_resnet50_v1b_coco)はCPUでは遅すぎて,リアルタイム処理には使い物にならない...
import mxnet as mx
import time
import gluoncv as gcv
from gluoncv.utils import try_import_cv2
cv2 = try_import_cv2()

# Instance segmentation on webcam frames with a pretrained Mask R-CNN.
# Press 'q' in the display window to quit.

net = gcv.model_zoo.get_model(
    'mask_rcnn_resnet50_v1b_coco',  # toooooo slow
    pretrained=True)
net.hybridize()  # let MXNet build/cache a static graph for the network

cap = cv2.VideoCapture(0)  # device 0 = default camera
time.sleep(1)  # presumably gives the camera time to start before the first read
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            # No frame from the camera: stop instead of passing None to cvtColor.
            break
        # OpenCV delivers BGR; convert to an RGB uint8 NDArray for GluonCV.
        frame = mx.nd.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).astype('uint8')
        # x is the network input; orig_img is the resized RGB image used for drawing.
        x, orig_img = gcv.data.transforms.presets.rcnn.transform_test(
            frame, short=512, max_size=700
        )
        y = net(x)
        # Take the first (only) image of the batch and move results to NumPy.
        ids, scores, bboxes, masks = [xx[0].asnumpy() for xx in y]
        # paint segmentation mask on images directly
        width, height = orig_img.shape[1], orig_img.shape[0]
        masks, _ = gcv.utils.viz.expand_mask(masks, bboxes, (width, height), scores)
        orig_img = gcv.utils.viz.plot_mask(orig_img, masks)
        # identical to Faster RCNN object detection
        seg_img = gcv.utils.viz.cv_plot_bbox(orig_img,
                                             bboxes,
                                             scores,
                                             ids,
                                             class_names=net.classes)
        # The image is RGB (converted above) but cv2.imshow expects BGR.
        cv2.imshow('seg_img', cv2.cvtColor(seg_img, cv2.COLOR_RGB2BGR))
        # Poll the keyboard exactly once per iteration: this call both refreshes
        # the window and reads the key. A second, unchecked waitKey would consume
        # (and discard) a 'q' pressed during the slow inference step.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Release the camera and close windows even on an exception or Ctrl+C.
    cap.release()
    cv2.destroyAllWindows()