[Windows][Python] OpenCV 3.3.1 dnn module sample (mobilenet_ssd_python.py)

Posted at 2017-11-24

I tried running the Python version of the MobileNet sample for the DNN (deep neural network) module,
which has been officially supported since OpenCV 3.3.

The pretrained model is supposed to be at the URL below, but the link was broken:
 https://github.com/chuanqi305/MobileNet-SSD/blob/master/MobileNetSSD_train.caffemodel
So, for the time being, I downloaded it from the following URL instead:
 https://drive.google.com/file/d/0BwY-lpO6tzxHRHNCdlRKczIzaEU/view?usp=sharing
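
Because the model has to be downloaded by hand, it can be worth checking that both files are actually on disk before handing them to the script. The following is only a minimal sketch: the helper name `missing_model_files` is hypothetical, and the file names are assumptions taken from the sample's default prototxt name and the download URL above.

```python
import os


def missing_model_files(prototxt, caffemodel):
    """Return the given paths that are unset or do not point to an existing file."""
    return [p for p in (prototxt, caffemodel) if not p or not os.path.isfile(p)]


if __name__ == "__main__":
    # Assumed file names; replace them with wherever you saved the downloads.
    missing = missing_model_files("MobileNetSSD_300x300.prototxt",
                                  "MobileNetSSD_train.caffemodel")
    if missing:
        print("Missing model files:", missing)
    else:
        print("Both model files were found.")
```

An unset path (for example, when `--caffemodel` is omitted and argparse leaves it as `None`) is reported as missing too, so the problem shows up before the network is loaded.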

Here is a video of it running.
It runs reasonably well even on CPU only.
https://www.youtube.com/watch?v=Wg_NU0rYkMI

The source code is below.
It is basically the same as "opencv\sources\samples\dnn\mobilenet_ssd_python.py".
Perhaps because of how my environment is set up, it did not run as-is, so I modified a few places.


#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import argparse

try:
    #import cv2 as cv
    import cv2                #[Fix] The code below calls cv2 directly, so "as cv" was removed
    from cv2 import dnn       #[Fix] Import the dnn module
except ImportError:
    raise ImportError('Can\'t find OpenCV Python module. If you\'ve built it from sources without installation, '
                      'configure environment variable PYTHONPATH to "opencv_build_dir/lib" directory (with "python3" subdirectory if required)')

inWidth = 300
inHeight = 300
WHRatio = inWidth / float(inHeight)
inScaleFactor = 0.007843
meanVal = 127.5

classNames = ('background',
              'aeroplane', 'bicycle', 'bird', 'boat',
              'bottle', 'bus', 'car', 'cat', 'chair',
              'cow', 'diningtable', 'dog', 'horse',
              'motorbike', 'person', 'pottedplant',
              'sheep', 'sofa', 'train', 'tvmonitor')

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--video", help="path to video file. If empty, camera's stream will be used")
    parser.add_argument("--prototxt", default="MobileNetSSD_300x300.prototxt",
                        help="path to caffe prototxt")
    parser.add_argument("-c", "--caffemodel", help="path to caffemodel file, download it here: "
                             "https://github.com/chuanqi305/MobileNet-SSD/blob/master/MobileNetSSD_train.caffemodel")
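    # type=float so that a value given on the command line can be compared with the confidence
    parser.add_argument("--thr", default=0.2, type=float, help="confidence threshold to filter out weak detections")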
    args = parser.parse_args()

    net = dnn.readNetFromCaffe(args.prototxt, args.caffemodel)

    #if len(args.video):      #[Fix] args.video is None when --video is omitted, so len() cannot be called on it
    if args.video is not None:
        cap = cv2.VideoCapture(args.video)
    else:
        cap = cv2.VideoCapture(0)

    while True:
        # Capture frame-by-frame
        ret, frame = cap.read()
        if not ret:
            # End of the video, or the camera returned no frame
            break
        # Give the mean per channel so that it is subtracted from every channel
        blob = dnn.blobFromImage(frame, inScaleFactor, (inWidth, inHeight), (meanVal, meanVal, meanVal))
        net.setInput(blob)
        detections = net.forward()

        cols = frame.shape[1]
        rows = frame.shape[0]

        if cols / float(rows) > WHRatio:
            cropSize = (int(rows * WHRatio), rows)
        else:
            cropSize = (cols, int(cols / WHRatio))

        # Integer division: slice indices must be ints under Python 3
        y1 = (rows - cropSize[1]) // 2
        y2 = y1 + cropSize[1]
        x1 = (cols - cropSize[0]) // 2
        x2 = x1 + cropSize[0]
        frame = frame[y1:y2, x1:x2]

        cols = frame.shape[1]
        rows = frame.shape[0]

        for i in range(detections.shape[2]):
            confidence = detections[0, 0, i, 2]
            if confidence > args.thr:
                class_id = int(detections[0, 0, i, 1])

                xLeftBottom = int(detections[0, 0, i, 3] * cols)
                yLeftBottom = int(detections[0, 0, i, 4] * rows)
                xRightTop   = int(detections[0, 0, i, 5] * cols)
                yRightTop   = int(detections[0, 0, i, 6] * rows)

                cv2.rectangle(frame, (xLeftBottom, yLeftBottom), (xRightTop, yRightTop),
                              (0, 255, 0))
                label = classNames[class_id] + ": " + str(confidence)
                labelSize, baseLine = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)

                cv2.rectangle(frame, (xLeftBottom, yLeftBottom - labelSize[1]),
                                     (xLeftBottom + labelSize[0], yLeftBottom + baseLine),
                                     (255, 255, 255), cv2.FILLED)
                cv2.putText(frame, label, (xLeftBottom, yLeftBottom),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0))

        cv2.imshow("detections", frame)
        if cv2.waitKey(1) >= 0:
            break
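
The two pieces of arithmetic in the loop, the center crop and the conversion of normalized detections to pixel boxes, can be tried out without OpenCV. The sketch below is only an illustration of that logic: the helper names `center_crop_box` and `decode_detections` are hypothetical, and the detection layout (image id, class id, confidence, then four normalized box coordinates per row) is inferred from the indexing used in the sample above.

```python
import numpy as np


def center_crop_box(cols, rows, target_ratio):
    """Return (x1, y1, x2, y2) of the centered crop whose width/height ratio is target_ratio."""
    if cols / float(rows) > target_ratio:
        crop_w, crop_h = int(rows * target_ratio), rows
    else:
        crop_w, crop_h = cols, int(cols / target_ratio)
    # Integer division keeps the results usable as slice indices.
    x1 = (cols - crop_w) // 2
    y1 = (rows - crop_h) // 2
    return x1, y1, x1 + crop_w, y1 + crop_h


def decode_detections(detections, cols, rows, threshold):
    """Return a list of (class_id, confidence, (x_min, y_min, x_max, y_max)) in pixels."""
    scale = np.array([cols, rows, cols, rows])
    results = []
    for i in range(detections.shape[2]):
        confidence = float(detections[0, 0, i, 2])
        if confidence > threshold:
            class_id = int(detections[0, 0, i, 1])
            box = tuple(int(v) for v in detections[0, 0, i, 3:7] * scale)
            results.append((class_id, confidence, box))
    return results


if __name__ == "__main__":
    # A 640x480 frame cropped to the 1:1 ratio of the 300x300 network input.
    print(center_crop_box(640, 480, 1.0))

    # Two made-up detections: one above and one below the 0.2 threshold.
    fake = np.zeros((1, 1, 2, 7))
    fake[0, 0, 0] = [0, 15, 0.9, 0.25, 0.25, 0.75, 1.0]
    fake[0, 0, 1] = [0, 7, 0.1, 0.0, 0.0, 1.0, 1.0]
    print(decode_detections(fake, 300, 300, 0.2))
```

Keeping these as plain functions makes it easy to check the crop and box coordinates on made-up values before wiring them into the capture loop.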

That's all.
