Python
DeepLearning
TensorFlow

Object detectionで特定の物体を検出してみました

Tensorflowのobject detectionを使って、特定のオブジェクトを検出できるようにやってみました。

環境整備

CPUとGPUの実行速度を見てみたいので、anacondaで二つの仮想環境を作成しました。

conda create -n tfgpu python=3.6 anaconda
activate tfgpu
pip3 install --upgrade tensorflow-gpu
deactivate

conda create -n tfcpu python=3.6 anaconda
activate tfcpu
pip3 install --upgrade tensorflow
deactivate

gpu環境では、cudnn v6のみ対応しているらしいです。
元々cudnn v7が入っているので、v7をアンインストールし、v6を入れ替えました。

git clone https://github.com/tensorflow/models.git
でobject_detectionをダウンロードします。

インストールには、protocが必要となります。
windowsでのインストール方法はこちらを参考にしてください。

  • 拙者の環境
    • GPU: GeForce GTX 1080
    • CPU: Intel(R) Core(TM) i7-7700 CPU @ 3.60GHz
    • MEM: 16G
    • OS: Windows 10 Pro

写真を撮る

まずは、opencv+webカメラを使って、検出したい物体の写真を撮ります。
もちろん、スマホとかを使ってもできますが、写真の転送とかはめんどくさいから、PCで一気にやりました。

自分の場合は、リモートデスクトップを使っているので、PCのウェブカメラを使うためには、RemoteFX USBリダイレクトの設定が必要です。

capture_img.py
cap = cv2.VideoCapture(0)

counter = 1
while True:
  _, frame = cap.read()
  if frame is not None:
    frame = cv2.resize(frame,(inWidth,inHeight))
    cv2.imshow("capture", frame)
    if cv2.waitKey(1) & 0xFF == ord("c"): #cを押せば、キャプチャーする
      img_name = os.path.join(IMAGE_DIR,NAME+"_"+str(counter)+".jpg")
      cv2.imwrite(img_name, frame)
      print("Captured: "+img_name)
      counter+=1
    if cv2.waitKey(1) & 0xFF == ord("q"): #qを押せば、終了する
      break
  else:
    sys.exit(0)

一般的には、100枚前後の画像があれば十分らしいです。

キャプチャ.PNG

データセットを作成

1. opencvを使って、xmlファイルを作成します。
labeling.py
class BndBox:
  def __init__(self, img, fname):
    self.inHeight, self.inWidth = img.shape[:2]
    self.xmin = 0
    self.ymin = 0
    self.xmax = 0
    self.ymax = 0
    self.drawing = False
    self.img = img
    self.fname = fname

  def mouse_event(self, event, x, y, flags, param):
    if event == cv2.EVENT_LBUTTONDOWN:
      self.drawing = True
      self.xmin = x
      self.ymin = y
    elif event == cv2.EVENT_MOUSEMOVE:
      if self.drawing:
        img_copy = self.img.copy()
        cv2.rectangle(img_copy,(self.xmin,self.ymin),(x,y),(0,255,0),1)
        cv2.imshow(self.fname, img_copy)
    elif event == cv2.EVENT_LBUTTONUP:
      self.drawing = False
      self.xmax = x
      self.ymax = y
      img_copy = self.img.copy()
      cv2.rectangle(img_copy,(self.xmin,self.ymin),(x,y),(0,255,0),1)
      cv2.imshow(self.fname, img_copy)

  def clear(self):
    self.xmin = 0
    self.ymin = 0
    self.xmax = 0
    self.ymax = 0
    cv2.imshow(self.fname, self.img)

  def save(self):
    annotation = ET.Element('annotation')

    filename = ET.SubElement(annotation, 'filename')
    filename.text = self.fname + ".jpg"

    size = ET.SubElement(annotation, 'size')
    width = ET.SubElement(size, 'width')
    width.text = str(self.inWidth)
    height = ET.SubElement(size, 'height')
    height.text = str(self.inHeight)
    depth = ET.SubElement(size, 'depth')
    depth.text = "3"

    object = ET.SubElement(annotation, 'object')
    pose = ET.SubElement(object, 'pose')
    pose.text = "Unspecified"
    truncated = ET.SubElement(object, 'truncated')
    truncated.text = "0"
    difficult = ET.SubElement(object, 'difficult')
    difficult.text = "0"
    bndbox = ET.SubElement(object, 'bndbox')
    xmin = ET.SubElement(bndbox, 'xmin')
    xmin.text = str(self.xmin)
    ymin = ET.SubElement(bndbox, 'ymin')
    ymin.text = str(self.ymin)
    xmax = ET.SubElement(bndbox, 'xmax')
    xmax.text = str(self.xmax)
    ymax = ET.SubElement(bndbox, 'ymax')
    ymax.text = str(self.ymax)

    string = ET.tostring(annotation, 'utf-8')
    pretty_string = minidom.parseString(string).toprettyxml(indent='  ')

    xml_file = os.path.join(ANNOTATION_DIR,"xmls",self.fname + '.xml')
    with open(xml_file, 'w') as f:
      f.write(pretty_string)

# Start from here
files = glob.glob(IMAGE_DIR+"/*.jpg")
trainval = []

for f in files:
  img = cv2.imread(f)

  fname = os.path.splitext(os.path.basename(f))[0]
  bndBox = BndBox(img,fname)

  cv2.namedWindow(fname)
  cv2.setMouseCallback(fname, bndBox.mouse_event) #mouse event

  while (True):
    cv2.imshow(fname, img)
    if cv2.waitKey(1) & 0xFF == ord("n"): #nを押せば、xmlファイルを保存し、次画像へ
      bndBox.save()
      trainval.append(fname+" "+"1")
      print(bndBox.xmin,bndBox.ymin,bndBox.xmax,bndBox.ymax,fname+" saved")
      break
    if cv2.waitKey(1) & 0xFF == ord("c"): #cを押せば、クリアする
      bndBox.clear()

  cv2.destroyAllWindows()

txt_file = os.path.join(ANNOTATION_DIR,"trainval.txt")
with open(txt_file, "w", encoding="utf-8") as f: #save trainval.txt
  f.write("\n".join(trainval))

マウスを使って、物体の場所を長方形で囲んで、nを押せば、xmlファイルを作成できます。

キャプチャ2.PNG

作成したxmlファイルはこんな感じです。

test_1.xml
<?xml version="1.0" ?>
<annotation>
  <filename>test_1.jpg</filename>
  <size>
    <width>600</width>
    <height>400</height>
    <depth>3</depth>
  </size>
  <object>
    <pose>Unspecified</pose>
    <truncated>0</truncated>
    <difficult>0</difficult>
    <bndbox>
      <xmin>55</xmin>
      <ymin>91</ymin>
      <xmax>484</xmax>
      <ymax>297</ymax>
    </bndbox>
  </object>
</annotation>

最後に、目次の「trainval.txt」ファイルを生成します。

2. クラスの目次label_map.pbtxtを作成します。

一種だけなので、簡単に作成できます。

label_map.pbtxt
item {
  id: 1
  name: 'test'
}
3. Last

python create_tf_record.py
でデータセットを生成します。

学習

ここからは、object detectionのソースコードを使います。

1. pipelineの設定ファイル

ssd_mobilenet_v1モデルを使います。

2. 学習
python object_detection/train.py 
  --pipeline_config_path=../../data/ssd_mobilenet_v1.config 
  --train_dir=../../data/train 
  --logtostderr
3. 検証
  • 1万回前後:

individualImage.png
まだまだですね。

  • 15万回前後:

individualImage.png
だいぶ良くなりました。

Loss:

キャプチャ.PNG
だいたい4万回以降からLossの変化が少なくなりました。
別データセットに対する学習する際には、8万回まで減少し続ける場合もあります。

実際に使う

1. Freeze Model
python object_detection/export_inference_graph.py
  --input_type image_tensor
  --pipeline_config_path ../../data/ssd_mobilenet_v1.config 
  --trained_checkpoint_prefix ../../data/train/model.ckpt
  --output_directory ../../data/save
2. opencvで使ってみよう
detect.py
detection_graph = tf.Graph()
with detection_graph.as_default():
  od_graph_def = tf.GraphDef()
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
    serialized_graph = fid.read()
    od_graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(od_graph_def, name='')
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=CLASS_NUM, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

with detection_graph.as_default():
  with tf.Session(graph=detection_graph) as sess:
    # Definite input and output Tensors for detection_graph
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    # Each box represents a part of the image where a particular object was detected.
    detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
    # Each score represent how level of confidence for each of the objects.
    # Score is shown on the result image, together with the class label.
    detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
    detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
    num_detections = detection_graph.get_tensor_by_name('num_detections:0')

    cap = cv.VideoCapture(0)

    while True:
      _, frame = cap.read()
      image_np_expanded = np.expand_dims(frame, axis=0)

      (boxes, scores, classes, num) = sess.run(
          [detection_boxes, detection_scores, detection_classes, num_detections],
          feed_dict={image_tensor: image_np_expanded})

      vis_util.visualize_boxes_and_labels_on_image_array(
          frame,
          np.squeeze(boxes),
          np.squeeze(classes).astype(np.int32),
          np.squeeze(scores),
          category_index,
          use_normalized_coordinates=True,
          line_thickness=8,
          min_score_thresh=.8)

      cv.imshow("detections", frame)

      if cv.waitKey(1) >= 0:
        break

ezgif.com-gif-maker.gif

実は、背景が複雑な場合は、誤認識率も高くなります。
学習データは単純すぎると思います。
白い背景ではなくて、いろんな環境から写真を撮れば、もっといい結果が出るはずです。

CPU vs. GPU

拙者のGPUのメモリは8Gのみなので、学習と検証を同時に実行すれば、固まります。
逆に、CPUのメモリは16Gがあるので、GPUよりかなり遅いですが、同時に処理できます。

それはほっておいて、処理速度を見ってみよう。

GPU:

3.PNG
平均で2.5global_step/secです。

CPU:

キャプチャ.PNG
平均で0.4global_step/secです。

CPUよりGPUが約6倍速いです。

コード

Githubで掲載されました。