LoginSignup
38
37

More than 3 years have passed since last update.

TensorFlow Object Detection API のつかいかた(推論。Colabサンプル付き)

Last updated at Posted at 2020-11-21

記事を読めば、TensorFlow Object Detection API (推論部分)のつかいかたがわかります。
Colabでできます。

Colabサンプル

サンプルのセルを実行していくと、TensorFlowObjectDetectionAPIが試せます。
最後のセルのImage_Pathを自前の画像のものに変えると、自前の画像で物体検出できます。

TensorFlow公式 Model Zooにはいろんな種類のモデルがあります!

(モデルのトレーニングはこちら)
TensorFlow Object Detection API で物体検出モデルをトレーニング

(数枚の画像でできる簡易トレーニングはこちら)
TensorFlow Object Detection APIで物体検出モデルを簡易トレーニング

ダウンロード.png

手順

0.TensorFlow2をインストール

!pip install -U --pre tensorflow=="2.2.0"

1.TensorFlow 公式 Models をGitHubからClone

import os
import pathlib

# 現在のディレクトリパスにmodelsが含まれていれば、そこに移動する。なければクローンする。
if "models" in pathlib.Path.cwd().parts:
  while "models" in pathlib.Path.cwd().parts:
    os.chdir('..')
elif not pathlib.Path('models').exists():
  !git clone --depth 1 https://github.com/tensorflow/models

2.Object Detection API、必要なモジュールをインストール

%%bash #bashコマンドを有効に
cd models/research/
protoc object_detection/protos/*.proto --python_out=.
cp object_detection/packages/tf2/setup.py .
python -m pip install . 

3.モジュールのインポート

import matplotlib
import matplotlib.pyplot as plt

import io
import scipy.misc
import numpy as np
from six import BytesIO
from PIL import Image, ImageDraw, ImageFont

import tensorflow as tf

from object_detection.utils import label_map_util
from object_detection.utils import config_util
from object_detection.utils import visualization_utils as viz_utils
from object_detection.builders import model_builder

%matplotlib inline

4.画像読み込み関数

def load_image_into_numpy_array(path):
  """画像を読み込んでNumpy Arrayにする

  Puts image into numpy array to feed into tensorflow graph.
  Note that by convention we put it into a numpy array with shape
  (height, width, channels), where channels=3 for RGB.

  Args:
    path: the file path to the image

  Returns:
    uint8 numpy array with shape (img_height, img_width, 3)
  """
  img_data = tf.io.gfile.GFile(path, 'rb').read()
  image = Image.open(BytesIO(img_data))
  (im_width, im_height) = image.size
  return np.array(image.getdata()).reshape(
      (im_height, im_width, 3)).astype(np.uint8)

def get_keypoint_tuples(eval_config):
  """Return a tuple list of keypoint edges from the eval config.

  Args:
    eval_config: an eval config containing the keypoint edges

  Returns:
    a list of edge tuples, each in the format (start, end)
  """
  tuple_list = []
  kp_list = eval_config.keypoint_edge
  for edge in kp_list:
    tuple_list.append((edge.start, edge.end))
  return tuple_list

5.モデルをダウンロード

!wget http://download.tensorflow.org/models/object_detection/tf2/20200713/centernet_hg104_512x512_coco17_tpu-8.tar.gz
!tar -xf centernet_hg104_512x512_coco17_tpu-8.tar.gz

公式 Model Zoo から好きなモデルをダウンロードします。
ダウンロードURLは上記リンクのモデル名にカーソルを合わせると表示されます。

スクリーンショット 2020-11-21 1.30.20.png

性能比較を見てるだけでたのしいですね。
ダウンロード・解凍が完了すると、checkpoint、saved_model、pipline.configを含むフォルダが展開されます。
スクリーンショット 2020-11-21 1.42.25.png

6.pipeline config(モデルの構成情報)を読み込んで、モデルをビルドする

#構成情報ファイルのパス。リポジトリにConfigファイルの揃ったフォルダがあるけど、微妙にモデル名が省略されていたりするので、ダウンロードしたものの方が確実?
pipeline_config = "./centernet_hg104_512x512_coco17_tpu-8/pipeline.config"
#チェックポイントのパス 
model_dir = "./centernet_hg104_512x512_coco17_tpu-8/checkpoint"

#モデル構成情報読み込み
configs = config_util.get_configs_from_pipeline_file(pipeline_config)
model_config = configs['model']

#読み込んだ構成情報でモデルをビルド
detection_model = model_builder.build(
      model_config=model_config, is_training=False)

#チェックポイントから重みを復元
ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
ckpt.restore(os.path.join(model_dir, 'ckpt-0')).expect_partial()

7.モデルによる推論関数を準備

def get_model_detection_function(model):
  """Get a tf.function for detection."""

  @tf.function
  def detect_fn(image):
    """Detect objects in image."""

    image, shapes = model.preprocess(image)
    prediction_dict = model.predict(image, shapes)
    detections = model.postprocess(prediction_dict, shapes)

    return detections, prediction_dict, tf.reshape(shapes, [-1])

  return detect_fn

detect_fn = get_model_detection_function(detection_model)

8.ラベルを準備

物体検出の推論ではモデルをトレーニングをした物体ラベルが必要です。
ラベルは公式リポジトリの models/research/object_detection/data/ にあります。今回のモデルはCocoデータセットでトレーニングされているので、 mscoco_label_map.pbtxt を使います。

label_map_path = './models/research/object_detection/data/mscoco_label_map.pbtxt'
label_map = label_map_util.load_labelmap(label_map_path)
categories = label_map_util.convert_label_map_to_categories(
    label_map,
    max_num_classes=label_map_util.get_max_label_map_index(label_map),
    use_display_name=True)
category_index = label_map_util.create_category_index(categories)
label_map_dict = label_map_util.get_label_map_dict(label_map, use_display_name=True)

9.用意した画像で物体検出の実行

Colabに好きな画像をアップロードして、image_pathにそのパスを指定します。
ちなみに、アルファチャネル付きの画像は3チャネルに直してから実行する必要があるっぽいです。

image_dir = 'models/research/object_detection/test_images/'
image_path = os.path.join(image_dir, 'image2.jpg')
image_np = load_image_into_numpy_array(image_path)

# Things to try:
# Flip horizontally
# image_np = np.fliplr(image_np).copy()

# Convert image to grayscale
# image_np = np.tile(
#     np.mean(image_np, 2, keepdims=True), (1, 1, 3)).astype(np.uint8)

input_tensor = tf.convert_to_tensor(
    np.expand_dims(image_np, 0), dtype=tf.float32)
detections, predictions_dict, shapes = detect_fn(input_tensor)

label_id_offset = 1
image_np_with_detections = image_np.copy()

# Use keypoints if available in detections
keypoints, keypoint_scores = None, None
if 'detection_keypoints' in detections:
  keypoints = detections['detection_keypoints'][0].numpy()
  keypoint_scores = detections['detection_keypoint_scores'][0].numpy()

viz_utils.visualize_boxes_and_labels_on_image_array(
      image_np_with_detections,
      detections['detection_boxes'][0].numpy(),
      (detections['detection_classes'][0].numpy() + label_id_offset).astype(int),
      detections['detection_scores'][0].numpy(),
      category_index,
      use_normalized_coordinates=True,
      max_boxes_to_draw=200,
      min_score_thresh=.30,
      agnostic_mode=False,
      keypoints=keypoints,
      keypoint_scores=keypoint_scores,
      keypoint_edges=get_keypoint_tuples(configs['eval_config']))

plt.figure(figsize=(12,16))
plt.imshow(image_np_with_detections)
plt.show()

ダウンロード.png

ボックス、ラベル、信頼度が表示されます。
スクリーンショット 2020-11-21 19.12.21.png

🐣


お仕事のご相談こちらまで
rockyshikoku@gmail.com

Core MLを使ったアプリを作っています。
機械学習関連の情報を発信しています。

Twitter
Medium
ホームページ

38
37
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
38
37