
Notes on object detection with TensorFlow Hub modules

Posted at 2019-03-14

Contents

  • How to use the Faster_RCNN and SSD modules (pre-trained models + parameters) available on TensorFlow Hub

References

TensorFlow Hub is a library for reusable machine learning modules.

faster_rcnn/openimages_v4/inception_resnet_v2

openimages_v4/ssd/mobilenet_v2

Importing TensorFlow Hub

# This note uses the TensorFlow 1.x API (hub.Module, tf.Session)
import tensorflow as tf
import tensorflow_hub as hub

Choosing a module

  • Faster R-CNN + Inception-ResNet V2: higher accuracy
  • SSD + MobileNet V2: smaller and faster
detector = hub.Module("https://tfhub.dev/google/faster_rcnn/openimages_v4/inception_resnet_v2/1")
# detector = hub.Module("https://tfhub.dev/google/openimages_v4/ssd/mobilenet_v2/1")

Loading and converting the image

Pay attention to the tensor's shape and value range.

Inputs (from the module's documentation):
A three-channel image of variable size - the model does NOT support batching. The input tensor is a tf.float32 tensor with shape [1, height, width, 3] with values in [0.0, 1.0].
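To catch shape or range mistakes before building the graph, the contract quoted above can be checked on the NumPy side. This is a minimal sketch; `check_detector_input` is a hypothetical helper, not part of TensorFlow Hub, and it only tests dtype, shape and value range.

```python
import numpy as np


def check_detector_input(x: np.ndarray) -> None:
    """Raise ValueError if x does not match the documented input contract:
    float32, shape [1, height, width, 3], values in [0.0, 1.0]."""
    if x.dtype != np.float32:
        raise ValueError(f"expected float32, got {x.dtype}")
    # the module does not support batching, so the first dimension must be 1
    if x.ndim != 4 or x.shape[0] != 1 or x.shape[3] != 3:
        raise ValueError(f"expected shape [1, H, W, 3], got {list(x.shape)}")
    # skip the range check for empty arrays (min/max would fail on them)
    if x.size and (x.min() < 0.0 or x.max() > 1.0):
        raise ValueError("values must lie in [0.0, 1.0]")


# Usage: a batched 2x3 RGB image already scaled to [0, 1]
ok = np.zeros((1, 2, 3, 3), dtype=np.float32)
check_detector_input(ok)  # passes silently
```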

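import cv2
import numpy as np

path = '../'
file = 'xxx.jpg'

# OpenCV loads images as BGR uint8 and returns None (no exception) if reading fails
im_bgr = cv2.imread(path + file)
if im_bgr is None:
  raise FileNotFoundError(path + file)
im = cv2.cvtColor(im_bgr, cv2.COLOR_BGR2RGB)  # the module expects RGB

image_original = im.copy()    # untouched copy for display
image_np = im.copy() / 255.0  # scale to [0.0, 1.0]

# add a batch dimension -> shape [1, height, width, 3], dtype float32
image_tensor = tf.convert_to_tensor(np.expand_dims(image_np, axis=0), dtype=tf.float32)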
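Note that `im.copy() / 255.0` yields a float64 array, which is why `tf.convert_to_tensor` is given `dtype=tf.float32`. The same preprocessing can be sketched in plain NumPy, assuming an OpenCV-style HxWx3 uint8 BGR array; for a 3-channel image, reversing the last axis does the same as `cv2.COLOR_BGR2RGB`. `bgr_to_model_input` is a hypothetical name.

```python
import numpy as np


def bgr_to_model_input(im_bgr: np.ndarray) -> np.ndarray:
    """Turn an HxWx3 uint8 BGR image into a [1, H, W, 3]
    float32 array with values in [0.0, 1.0]."""
    im_rgb = im_bgr[..., ::-1]                # reverse channel axis: BGR -> RGB
    scaled = im_rgb / 255.0                   # true division yields float64
    batched = np.expand_dims(scaled, axis=0)  # add the batch dimension
    return batched.astype(np.float32)         # match the module's tf.float32 input


# Usage: a single blue pixel in BGR order
px = np.array([[[255, 0, 0]]], dtype=np.uint8)
x = bgr_to_model_input(px)
print(x.shape, x.dtype, x[0, 0, 0])  # (1, 1, 1, 3) float32 [0. 0. 1.]
```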


Inference

Effectively, these five lines are all the code needed to run the detector.

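# build the detection graph (TF 1.x graph mode)
detector_output = detector(image_tensor, as_dict=True)

# the module's variables and lookup tables (used for class names) must be initialized
init_ops = [tf.global_variables_initializer(), tf.tables_initializer()]

with tf.Session() as sess:
  sess.run(init_ops)
  # result_out is a dict of NumPy arrays keyed by output name
  result_out = sess.run(detector_output)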
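`result_out` is a plain dict of NumPy arrays, one entry per detection, with boxes given as normalized `[ymin, xmin, ymax, xmax]` (the order `draw_boxes` unpacks). A minimal sketch for listing confident detections as text, assuming that layout and byte-string entity names; `filter_detections` is hypothetical and the sample values are made up, not real model output.

```python
import numpy as np


def filter_detections(result, min_score=0.3):
    """Keep only detections whose score is at least min_score.

    `result` is assumed to be shaped like result_out: parallel arrays indexed
    by detection, boxes as [ymin, xmin, ymax, xmax] in [0, 1]."""
    keep = result["detection_scores"] >= min_score  # boolean mask
    return {
        "boxes": result["detection_boxes"][keep],
        # entity names come back from sess.run as bytes
        "entities": [e.decode("utf-8") for e in result["detection_class_entities"][keep]],
        "scores": result["detection_scores"][keep],
    }


# Usage with made-up values (not real model output)
fake = {
    "detection_boxes": np.array([[0.1, 0.2, 0.5, 0.6], [0.0, 0.0, 1.0, 1.0]], dtype=np.float32),
    "detection_class_entities": np.array([b"Cat", b"Person"], dtype=object),
    "detection_scores": np.array([0.9, 0.05], dtype=np.float32),
}
kept = filter_detections(fake)
print(kept["entities"], kept["scores"])
```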

Results

Helper functions for drawing the results.

from PIL import Image, ImageColor, ImageDraw, ImageFont


def draw_bounding_box_on_image(image,
                               ymin,
                               xmin,
                               ymax,
                               xmax,
                               color,
                               font,
                               thickness=2,
                               display_str_list=()):
  """Adds a bounding box to an image."""
  draw = ImageDraw.Draw(image)
  im_width, im_height = image.size
  (left, right, top, bottom) = (xmin * im_width, xmax * im_width,
                                ymin * im_height, ymax * im_height)
  draw.line([(left, top), (left, bottom), (right, bottom), (right, top),
             (left, top)],
            width=thickness,
            fill=color)

  # If the total height of the display strings added to the top of the bounding
  # box exceeds the top of the image, stack the strings below the bounding box
  # instead of above.
  # Note: font.getsize() was removed in Pillow 10; this code assumes an older Pillow.
  display_str_heights = [font.getsize(ds)[1] for ds in display_str_list]
  # Each display_str has a top and bottom margin of 0.05x.
  total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights)

  if top > total_display_str_height:
    text_bottom = top
  else:
    text_bottom = bottom + total_display_str_height
  # Reverse list and print from bottom to top.
  for display_str in display_str_list[::-1]:
    text_width, text_height = font.getsize(display_str)
    margin = np.ceil(0.05 * text_height)
    draw.rectangle([(left, text_bottom - text_height - 2 * margin),
                    (left + text_width, text_bottom)],
                   fill=color)
    draw.text((left + margin, text_bottom - text_height - margin),
              display_str,
              fill="black",
              font=font)
    text_bottom -= text_height - 2 * margin
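The geometry in `draw_bounding_box_on_image` does not depend on PIL, so it can be pulled out and checked in isolation. A sketch with two hypothetical helpers that mirror the function's own arithmetic: scaling normalized coordinates by the image size, and placing the label stack above the box only when there is enough room between the box and the top of the image.

```python
def box_to_pixels(ymin, xmin, ymax, xmax, im_width, im_height):
    """Convert normalized [0, 1] box coordinates to pixel (left, right, top, bottom)."""
    return (xmin * im_width, xmax * im_width, ymin * im_height, ymax * im_height)


def label_bottom(top, bottom, text_heights, margin_ratio=0.05):
    """Return the y coordinate where the label stack ends: the box top if the
    labels (plus margins) fit above the box, otherwise just below the box."""
    total = (1 + 2 * margin_ratio) * sum(text_heights)
    return top if top > total else bottom + total


# Usage: a box covering the middle of a 640x480 image, one 11-px label
left, right, top, bottom = box_to_pixels(0.25, 0.25, 0.75, 0.75, 640, 480)
print(left, right, top, bottom)         # 160.0 480.0 120.0 360.0
print(label_bottom(top, bottom, [11]))  # 120.0 (label fits above the box)
```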


def draw_boxes(image, boxes, class_names, scores, max_boxes=10, min_score=0.1):
  """Overlay labeled boxes on an image with formatted scores and label names."""
  colors = list(ImageColor.colormap.values())

  #try:
  #  font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationSansNarrow-Regular.ttf", 25)
  #except IOError:
  #  print("Font not found, using default font.")
  #  font = ImageFont.load_default()
  font = ImageFont.load_default()
  
  for i in range(min(boxes.shape[0], max_boxes)):
    if scores[i] >= min_score:
      ymin, xmin, ymax, xmax = tuple(boxes[i].tolist())
      display_str = "{}: {}%".format(class_names[i].decode("ascii"),
                                     int(100 * scores[i]))
      color = colors[hash(class_names[i]) % len(colors)]
      image_pil = Image.fromarray(np.uint8(image)).convert("RGB")
      draw_bounding_box_on_image(
          image_pil,
          ymin,
          xmin,
          ymax,
          xmax,
          color,
          font,
          display_str_list=[display_str])
      np.copyto(image, np.array(image_pil))
  return image
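`draw_boxes` picks a color with `hash(class_names[i])`. Python randomizes `hash()` for `str` and `bytes` per process unless `PYTHONHASHSEED` is set, so the same class can get a different color on each run. A sketch of a run-to-run stable alternative using `zlib.crc32`; `stable_color` and its palette are hypothetical, and in `draw_boxes` the `colors` list would take the palette's place.

```python
import zlib

# hypothetical palette; draw_boxes uses list(ImageColor.colormap.values()) instead
PALETTE = ["red", "green", "blue", "orange", "purple", "cyan"]


def stable_color(class_name: bytes, palette=PALETTE) -> str:
    """Pick a palette entry from a CRC-32 checksum of the class name.

    Unlike the built-in hash(), zlib.crc32 does not depend on PYTHONHASHSEED,
    so a class keeps the same color across runs."""
    return palette[zlib.crc32(class_name) % len(palette)]


print(stable_color(b"Person"))
```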

Drawing the boxes and plotting.



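# module outputs: parallel arrays, one entry per detection
boxes = result_out['detection_boxes']
class_entities = result_out['detection_class_entities']
class_names = result_out['detection_class_names']
class_labels = result_out['detection_class_labels']
scores = result_out['detection_scores']

# counts every returned detection, including low-scoring ones
print("Found %d objects." % len(scores))

# draw_boxes writes into its input array; image_original keeps a clean copy
image_with_boxes = draw_boxes(im, boxes, class_entities, scores)
print('shape: ', image_with_boxes.shape)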


import matplotlib.pyplot as plt

fig = plt.figure(figsize=(15, 10))
ax1 = fig.add_subplot(1, 2, 1)
ax1.imshow(image_original)
ax1.set_xticks([])
ax1.set_yticks([])
ax1.set_title('Original')

ax2 = fig.add_subplot(1, 2, 2)
ax2.imshow(image_with_boxes)
ax2.set_xticks([])
ax2.set_yticks([])
ax2.set_title('With boxes')

plt.show()
