LoginSignup
7
9

More than 3 years have passed since last update.

YOLOv3で物体検出 with TensorFlow 2 + Colaboratory

Last updated at Posted at 2020-06-28

はじめに

去年からYOLOを使ってシステムを構築していたのですが、TensorFlowのバージョンが2になってから従来の手法(keras-yolo3)が利用できなくなってしまいました。

結果として、TensorFlowのバージョンを1.14に下げて運用するようにしていたのですが、どうしてもTensorFlow2系を使わなくてはならず、方法を探していたらうまくできたのでまとめておきます。

動作環境

ここで紹介した内容は以下のURLにて確認することができます。
https://colab.research.google.com/drive/1IICGm0pA93JKqtvaJumHTW1MGO7_uLi8?usp=sharing

Colaboratoryの確認

新しいノートブックを作成し、以下のコードを実行して環境の確認をします。

Pythonのバージョン

!python -V

実行するとPythonのバージョンが確認できます。

Python 3.6.9

TensorFlowのバージョン

import tensorflow as tf
tf.__version__

実行するとTensorFlowのバージョンが確認できます。

2.2.0

YOLOの準備

# yolov3-tf2のダウンロード
!git clone https://github.com/zzh8829/yolov3-tf2.git

# ダウンロードしたディレクトリ内のファイルを移動
!mv yolov3-tf2/* ./
!rm -R -f yolov3-tf2

# weightsをダウンロード
!wget https://pjreddie.com/media/files/yolov3.weights -O data/yolov3.weights

# weightsをTensorFlow用に変換
!python convert.py --weights ./data/yolov3.weights --output ./checkpoints/yolov3.tf

# 変換後のファイルを確認
!ls -l checkpoints

cxheckpointsに以下のファイルが保存されます。

total 243312
-rw-r--r-- 1 root root        75 Jun 28 10:12 checkpoint
-rw-r--r-- 1 root root 249118743 Jun 28 10:12 yolov3.tf.data-00000-of-00001
-rw-r--r-- 1 root root     24143 Jun 28 10:12 yolov3.tf.index

ライブラリの読込

import time
from absl import app, flags, logging
from absl.flags import FLAGS
import cv2
import numpy as np
import tensorflow as tf
from yolov3_tf2.models import (
    YoloV3, YoloV3Tiny
)
from yolov3_tf2.dataset import transform_images, load_tfrecord_dataset
from yolov3_tf2.utils import draw_outputs

from matplotlib import pyplot as plt

環境設定

flags.DEFINE_string('classes', './data/coco.names', 'path to classes file')
flags.DEFINE_string('weights', './checkpoints/yolov3.tf',
                    'path to weights file')
flags.DEFINE_boolean('tiny', False, 'yolov3 or yolov3-tiny')
flags.DEFINE_integer('size', 416, 'resize images to')
flags.DEFINE_string('image', './imgs/20200628120923255340.jpg', 'path to input image')
flags.DEFINE_string('tfrecord', None, 'tfrecord instead of image')
flags.DEFINE_string('output', './output.jpg', 'path to output image')
flags.DEFINE_integer('num_classes', 80, 'number of classes in the model')

flags.DEFINE_string('f', '', 'kernel')

検出結果の可視化・取得用関数

def show_outputs(img, outputs, class_names):

    boxes, objectness, classes, nums = outputs
    boxes, objectness, classes, nums = boxes[0], objectness[0], classes[0], nums[0]
    wh = np.flip(img.shape[0:2])

    imgs = []

    for i in range(nums):

        pos = (np.array(boxes[i][0:2]) * wh).astype(np.int32)
        x1 = pos[0]
        y1 = pos[1]

        pos = (np.array(boxes[i][2:4]) * wh).astype(np.int32)
        x2 = pos[0]
        y2 = pos[1]

        print(boxes[i])
        print(x1, y1, x2, y2)

        img_dst = img[y1:y2, x1:x2]        
        print(img_dst.shape)

        print(class_names[int(classes[i])])

        plt.imshow(cv2.cvtColor(img_dst, cv2.COLOR_BGR2RGB))
        plt.show()

        imgs.append(img_dst)

    return imgs

物体検出用関数

def main(_argv):

    physical_devices = tf.config.experimental.list_physical_devices('GPU')
    for physical_device in physical_devices:
        tf.config.experimental.set_memory_growth(physical_device, True)

    if FLAGS.tiny:
        yolo = YoloV3Tiny(classes=FLAGS.num_classes)
    else:
        yolo = YoloV3(classes=FLAGS.num_classes)

    yolo.load_weights(FLAGS.weights).expect_partial()
    logging.info('weights loaded')

    class_names = [c.strip() for c in open(FLAGS.classes).readlines()]
    logging.info('classes loaded')

    filenames = ["data/girl.png", "data/street.jpg"]

    for filename in filenames:

        img_raw = tf.image.decode_image(
            open(filename, 'rb').read(), channels=3)

        plt.imshow(img_raw)
        plt.show()

        img = tf.expand_dims(img_raw, 0)
        img = transform_images(img, FLAGS.size)

        t1 = time.time()
        boxes, scores, classes, nums = yolo(img)
        t2 = time.time()
        logging.info('time: {}'.format(t2 - t1))

        logging.info('detections:')
        for i in range(nums[0]):
            logging.info('\t{}, {}, {}'.format(class_names[int(classes[0][i])],
                                               np.array(scores[0][i]),
                                               np.array(boxes[0][i])))

        img = cv2.cvtColor(img_raw.numpy(), cv2.COLOR_RGB2BGR)

        imgs = show_outputs(img, (boxes, scores, classes, nums), class_names)

        i = 0

        for im in imgs:
            output_filename = "./{}_{}.jpg".format(filename.split("/")[-1].split(".")[0], i)
            cv2.imwrite(output_filename, im)

            logging.info('output saved to: {}'.format(output_filename))

            i += 1

実行

try:
    app.run(main)
except SystemExit:
    pass

実行すると以下のように表示されます。

スクリーンショット 2020-06-28 21.26.24.png

できた!

7
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
9