44
50

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Object detectionで特定の物体を検出してみました

Posted at

Tensorflowのobject detectionを使って、特定のオブジェクトを検出できるようにやってみました。

##環境整備

CPUとGPUの実行速度を見てみたいので、anacondaで二つの仮想環境を作成しました。

conda create -n tfgpu python=3.6 anaconda
activate tfgpu
pip3 install --upgrade tensorflow-gpu
deactivate

conda create -n tfcpu python=3.6 anaconda
activate tfcpu
pip3 install --upgrade tensorflow
deactivate

gpu環境では、cudnn v6のみ対応しているらしいです。
元々cudnn v7が入っているので、v7をアンインストールし、v6を入れ替えました。

git clone https://github.com/tensorflow/models.git
でobject_detectionをダウンロードします。

インストールには、protocが必要となります。
windowsでのインストール方法はこちらを参考にしてください。

  • 拙者の環境
  • GPU: GeForce GTX 1080
  • CPU: Intel(R) Core(TM) i7-7700 CPU @ 3.60GHz
  • MEM: 16G
  • OS: Windows 10 Pro

##写真を撮る

まずは、opencv+webカメラを使って、検出したい物体の写真を撮ります。
もちろん、スマホとかを使ってもできますが、写真の転送とかはめんどくさいから、PCで一気にやりました。

自分の場合は、リモートデスクトップを使っているので、PCのウェブカメラを使うためには、RemoteFX USBリダイレクトの設定が必要です。

capture_img.py
cap = cv2.VideoCapture(0)

counter = 1
while True:
  _, frame = cap.read()
  if frame is not None:
    frame = cv2.resize(frame,(inWidth,inHeight))
    cv2.imshow("capture", frame)
    if cv2.waitKey(1) & 0xFF == ord("c"): #cを押せば、キャプチャーする
      img_name = os.path.join(IMAGE_DIR,NAME+"_"+str(counter)+".jpg")
      cv2.imwrite(img_name, frame)
      print("Captured: "+img_name)
      counter+=1
    if cv2.waitKey(1) & 0xFF == ord("q"): #qを押せば、終了する
      break
  else:
    sys.exit(0)

一般的には、100枚前後の画像があれば十分らしいです。

キャプチャ.PNG

##データセットを作成

#####1. opencvを使って、xmlファイルを作成します。

labeling.py
class BndBox:
  def __init__(self, img, fname):
    self.inHeight, self.inWidth = img.shape[:2]
    self.xmin = 0
    self.ymin = 0
    self.xmax = 0
    self.ymax = 0
    self.drawing = False
    self.img = img
    self.fname = fname
    
  def mouse_event(self, event, x, y, flags, param):
    if event == cv2.EVENT_LBUTTONDOWN:
      self.drawing = True
      self.xmin = x
      self.ymin = y
    elif event == cv2.EVENT_MOUSEMOVE:
      if self.drawing:
        img_copy = self.img.copy()
        cv2.rectangle(img_copy,(self.xmin,self.ymin),(x,y),(0,255,0),1)
        cv2.imshow(self.fname, img_copy)
    elif event == cv2.EVENT_LBUTTONUP:
      self.drawing = False
      self.xmax = x
      self.ymax = y
      img_copy = self.img.copy()
      cv2.rectangle(img_copy,(self.xmin,self.ymin),(x,y),(0,255,0),1)
      cv2.imshow(self.fname, img_copy)
      
  def clear(self):
    self.xmin = 0
    self.ymin = 0
    self.xmax = 0
    self.ymax = 0
    cv2.imshow(self.fname, self.img)
    
  def save(self):
    annotation = ET.Element('annotation')
    
    filename = ET.SubElement(annotation, 'filename')
    filename.text = self.fname + ".jpg"
    
    size = ET.SubElement(annotation, 'size')
    width = ET.SubElement(size, 'width')
    width.text = str(self.inWidth)
    height = ET.SubElement(size, 'height')
    height.text = str(self.inHeight)
    depth = ET.SubElement(size, 'depth')
    depth.text = "3"
    
    object = ET.SubElement(annotation, 'object')
    pose = ET.SubElement(object, 'pose')
    pose.text = "Unspecified"
    truncated = ET.SubElement(object, 'truncated')
    truncated.text = "0"
    difficult = ET.SubElement(object, 'difficult')
    difficult.text = "0"
    bndbox = ET.SubElement(object, 'bndbox')
    xmin = ET.SubElement(bndbox, 'xmin')
    xmin.text = str(self.xmin)
    ymin = ET.SubElement(bndbox, 'ymin')
    ymin.text = str(self.ymin)
    xmax = ET.SubElement(bndbox, 'xmax')
    xmax.text = str(self.xmax)
    ymax = ET.SubElement(bndbox, 'ymax')
    ymax.text = str(self.ymax)
    
    string = ET.tostring(annotation, 'utf-8')
    pretty_string = minidom.parseString(string).toprettyxml(indent='  ')
    
    xml_file = os.path.join(ANNOTATION_DIR,"xmls",self.fname + '.xml')
    with open(xml_file, 'w') as f:
      f.write(pretty_string)

# Start from here
files = glob.glob(IMAGE_DIR+"/*.jpg")
trainval = []
  
for f in files:
  img = cv2.imread(f)
    
  fname = os.path.splitext(os.path.basename(f))[0]
  bndBox = BndBox(img,fname)
    
  cv2.namedWindow(fname)
  cv2.setMouseCallback(fname, bndBox.mouse_event) #mouse event
    
  while (True):
    cv2.imshow(fname, img)
    if cv2.waitKey(1) & 0xFF == ord("n"): #nを押せば、xmlファイルを保存し、次画像へ
      bndBox.save()
      trainval.append(fname+" "+"1")
      print(bndBox.xmin,bndBox.ymin,bndBox.xmax,bndBox.ymax,fname+" saved")
      break
    if cv2.waitKey(1) & 0xFF == ord("c"): #cを押せば、クリアする
      bndBox.clear()
      
  cv2.destroyAllWindows()
  
txt_file = os.path.join(ANNOTATION_DIR,"trainval.txt")
with open(txt_file, "w", encoding="utf-8") as f: #save trainval.txt
  f.write("\n".join(trainval))

マウスを使って、物体の場所を長方形で囲んで、nを押せば、xmlファイルを作成できます。

キャプチャ2.PNG

作成したxmlファイルはこんな感じです。

test_1.xml
<?xml version="1.0" ?>
<annotation>
  <filename>test_1.jpg</filename>
  <size>
    <width>600</width>
    <height>400</height>
    <depth>3</depth>
  </size>
  <object>
    <pose>Unspecified</pose>
    <truncated>0</truncated>
    <difficult>0</difficult>
    <bndbox>
      <xmin>55</xmin>
      <ymin>91</ymin>
      <xmax>484</xmax>
      <ymax>297</ymax>
    </bndbox>
  </object>
</annotation>

最後に、目次の「trainval.txt」ファイルを生成します。

#####2. クラスの目次label_map.pbtxtを作成します。
一種だけなので、簡単に作成できます。

label_map.pbtxt
item {
  id: 1
  name: 'test'
}

#####3. Last
python create_tf_record.py
でデータセットを生成します。

##学習

ここからは、object detectionのソースコードを使います。

#####1. pipelineの設定ファイル
ssd_mobilenet_v1モデルを使います。

#####2. 学習

python object_detection/train.py 
  --pipeline_config_path=../../data/ssd_mobilenet_v1.config 
  --train_dir=../../data/train 
  --logtostderr

#####3. 検証

  • 1万回前後:

individualImage.png
まだまだですね。

  • 15万回前後:

individualImage.png
だいぶ良くなりました。

Loss:

キャプチャ.PNG
だいたい4万回以降からLossの変化が少なくなりました。
別データセットに対する学習する際には、8万回まで減少し続ける場合もあります。

##実際に使う

#####1. Freeze Model

python object_detection/export_inference_graph.py
  --input_type image_tensor
  --pipeline_config_path ../../data/ssd_mobilenet_v1.config 
  --trained_checkpoint_prefix ../../data/train/model.ckpt
  --output_directory ../../data/save

#####2. opencvで使ってみよう

detect.py
detection_graph = tf.Graph()
with detection_graph.as_default():
  od_graph_def = tf.GraphDef()
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
    serialized_graph = fid.read()
    od_graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(od_graph_def, name='')
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=CLASS_NUM, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

with detection_graph.as_default():
  with tf.Session(graph=detection_graph) as sess:
    # Definite input and output Tensors for detection_graph
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    # Each box represents a part of the image where a particular object was detected.
    detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
    # Each score represent how level of confidence for each of the objects.
    # Score is shown on the result image, together with the class label.
    detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
    detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
    num_detections = detection_graph.get_tensor_by_name('num_detections:0')

    cap = cv.VideoCapture(0)
    
    while True:
      _, frame = cap.read()
      image_np_expanded = np.expand_dims(frame, axis=0)
      
      (boxes, scores, classes, num) = sess.run(
          [detection_boxes, detection_scores, detection_classes, num_detections],
          feed_dict={image_tensor: image_np_expanded})
          
      vis_util.visualize_boxes_and_labels_on_image_array(
          frame,
          np.squeeze(boxes),
          np.squeeze(classes).astype(np.int32),
          np.squeeze(scores),
          category_index,
          use_normalized_coordinates=True,
          line_thickness=8,
          min_score_thresh=.8)

      cv.imshow("detections", frame)
      
      if cv.waitKey(1) >= 0:
        break

ezgif.com-gif-maker.gif

実は、背景が複雑な場合は、誤認識率も高くなります。
学習データは単純すぎると思います。
白い背景ではなくて、いろんな環境から写真を撮れば、もっといい結果が出るはずです。

##CPU vs. GPU

拙者のGPUのメモリは8Gのみなので、学習と検証を同時に実行すれば、固まります。
逆に、CPUのメモリは16Gがあるので、GPUよりかなり遅いですが、同時に処理できます。

それはほっておいて、処理速度を見ってみよう。

GPU:

3.PNG
平均で2.5global_step/secです。

CPU:

キャプチャ.PNG
平均で0.4global_step/secです。

CPUよりGPUが約6倍速いです。

##コード
Githubで掲載されました。

44
50
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
44
50

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?