LoginSignup
6
2

More than 3 years have passed since last update.

TensorFlow Object Detection API で独自のデータセットをつかうには、TFRecord ファイル形式にする必要があります。
データセットからTFRecordを作る手順です。

手順

tf_exampleを生成するコードスニペットを使うためにobject_detectionAPI をクローンします。


import os
import pathlib

if "models" in pathlib.Path.cwd().parts:
  while "models" in pathlib.Path.cwd().parts:
    os.chdir('..')
elif not pathlib.Path('models').exists():
  !git clone --depth 1 https://github.com/tensorflow/models

%%bash
cd models/research/
protoc object_detection/protos/*.proto --python_out=.
cp object_detection/packages/tf2/setup.py .
python -m pip install .

個々のデータを tf_example に変換する関数を定義

個々の画像とアノテーション情報をバイトに変換して、
tf_exampleという形式に変えます。

import tensorflow as tf
from object_detection.utils import dataset_util

def create_tf_example(height,
                      width,
                      filename,
                      image_format,
                      xmins,xmaxs,
                      ymins,
                      ymaxs,
                      classes_text,
                      classes):
  # TODO(user): Populate the following variables from your example.
  # height = None # Image height
  # width = None # Image width
  # filename = None # Filename of the image. Empty if image is not from file
  # encoded_image_data = None # Encoded image bytes
  # image_format = None # b'jpeg' or b'png'

  # xmins = [] # List of normalized left x coordinates in bounding box (1 per box)
  # xmaxs = [] # List of normalized right x coordinates in bounding box
  #            # (1 per box)
  # ymins = [] # List of normalized top y coordinates in bounding box (1 per box)
  # ymaxs = [] # List of normalized bottom y coordinates in bounding box
  #            # (1 per box)
  # classes_text = [] # List of string class name of bounding box (1 per box)
  # classes = [] # List of integer class id of bounding box (1 per box)

  with tf.io.gfile.GFile(filename, 'rb') as fid:
      encoded_jpg = fid.read()
      # encoded_jpg_io = io.BytesIO(encoded_jpg)

  tf_example = tf.train.Example(features=tf.train.Features(feature={
      'image/height': dataset_util.int64_feature(height),
      'image/width': dataset_util.int64_feature(width),
      'image/filename': dataset_util.bytes_feature(filename.encode('utf-8')),
      'image/source_id': dataset_util.bytes_feature(filename.encode('utf-8')),
      'image/encoded': dataset_util.bytes_feature(encoded_jpg),
      'image/format': dataset_util.bytes_feature(image_format),
      'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
      'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
      'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
      'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
      'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
      'image/object/class/label': dataset_util.int64_list_feature(classes),
  }))
  return tf_example

データセットをForLoop処理でtf_exampleにしてTFRecordWriterで書き込む

以下のようなアノテーションデータがあるとします。
boxは [minx, miny, maxx, maxy]


{
   "categories": [
      {
         "id": 1,
         "name": "cat"
      },
      {
         "id": 2,
         "name": "dog"
      }
   ],
   "annotations": [
      {
         "filename": "train_000.jpg",
         "image_height": 3840,
         "image_width": 2160,
         "labels": [
            1,
            1,
            2
         ],
         "label_texts": [
            "cat",
            "cat",
            "dog"
         ],
         "boxes": [
            [
               1250,
               790,
               1850,
               1300
            ],
            [
               920,
               1230,
               1310,
               1550
            ],
            [
               12,
               1180,
               550,
               1450
            ]
         ]
      },
...
      }
   ]
}

データセットを1画像分ずつ tf_example にして tf_records に書き込みます。
tf_example の中身の tf.train.Feature はバイトしか受け付けないので、データをバイトにして与える必要があります。

import tensorflow as tf
import os
import numpy as np
from PIL import Image



# from object_detection.utils import dataset_util

output_path = './data.tfrecords'
image_dir = './train_images/'
writer = tf.io.TFRecordWriter(output_path)

annotations = dataset['annotations']

for annotation in annotations:
   if annotation['boxes'] != []:
       height = annotation['image_height']
       width = annotation['image_width']
       filename = (image_dir + annotation['filename']).encode('utf-8')
       image_format = b'jpeg'

       xmins = []
       xmaxs = []
       ymins = []
       ymaxs = []

       for box in annotation['boxes']:
           xmins.append(box[0] / width) # 0~1に正規化
           xmaxs.append(box[2] / width)
           ymins.append(box[1] / height)
           ymaxs.append(box[3] / height)

       classes_text = []

       for text in annotation['label_texts']:
           classes_text.append(text.encode('utf-8'))
       classes = []
       for label in annotation['labels']:
           classes.append(bytes([label]))

       tf_example = create_tf_example(height,width,filename,image_format,xmins,xmaxs,ymins,ymaxs,classes_text,classes)
       writer.write(tf_example.SerializeToString())
   writer.close()

分割して書き込む

データセットが大きい場合は、TFRecordを分割してファイルにすると便利です。
公式のドキュメントによると

tf.data.Dataset APIは、入力例を並行して読み取ることができ、スループットを向上させます。
tf.data.Dataset APIは、モデルのパフォーマンスをわずかに向上させるシャードファイルを使用して例をより適切にシャッフルできます。

tf_example を生成し、分割して書き込みます。

import contextlib2
from object_detection.dataset_tools import tf_record_creation_util

num_shards=10
output_filebase='./train_dataset.record'

with contextlib2.ExitStack() as tf_record_close_stack:
  output_tfrecords = tf_record_creation_util.open_sharded_output_tfrecords(
      tf_record_close_stack, output_filebase, num_shards)
  annotations = dataset['annotations']

  for i in range(len(annotations)):
     if annotations[i]['boxes'] != []:
        height = annotations[i]['image_height']
        width = annotations[i]['image_width']
        filename = (image_dir + annotations[i]['filename']).encode('utf-8')
        image_format = b'jpeg'

        xmins = []
        xmaxs = []
        ymins = []
        ymaxs = []

        for box in annotations[i]['boxes']:
            xmins.append(box[0] / width) # 0~1に正規化
            xmaxs.append(box[2] / width)
            ymins.append(box[1] / height)
            ymaxs.append(box[3] / height)

        classes_text = []

        for text in annotations[i]['label_texts']:
            classes_text.append(text.encode('utf-8'))
        classes = []
        for label in annotations[i]['labels']:
            classes.append(bytes([label]))
        tf_example = create_tf_example(height,width,filename,image_format,xmins,xmaxs,ymins,ymaxs,classes_text,classes)
        output_shard_index = i % num_shards
        output_tfrecords[output_shard_index].write(tf_example.SerializeToString())

分割したファイルが生成されます。

./train_dataset.record-00000-00010
./train_dataset.record-00001-00010
...
./train_dataset.record-00009-00010

使用するときはConfigを以下に設定します

tf_record_input_reader { 
  input_path" /path/to/train_dataset.record-?????-of-00010 " 
}

🐣


フリーランスエンジニアです。
お仕事のご相談こちらまで
rockyshikoku@gmail.com

Core MLを使ったアプリを作っています。
機械学習関連の情報を発信しています。

Twitter
Medium

6
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
2