11
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

【Tensorflow-gpu 1.x系】複数のモデルを一つのプログラムで実行する

Last updated at Posted at 2019-08-20

はじめに

以下のように複数のモデル(手検出モデルと指先検出モデル)を1つのプログラムで実行しようとした時にGPUメモリで躓いたためメモ。
以下のプログラムでは手検出で1つ、指先検出で1つの計2つのモデルを利用しています。

躓いたポイント

1つ目のモデルを読み込んだ後、2つ目のモデルを動作させようとするとGPUメモリ不足でプログラムが異常終了する。

原因

TensorFlowはデフォルト設定では、GPUのメモリのほぼ全てを使用して、メモリの断片化を軽減するようにしているようです。
1つ目のモデルの時点でメモリをほぼ全て使い切ったため、2つ目のモデルは動かせなかった模様。

対処

Allowing GPU memory growthオプションを利用し、必要なメモリだけ確保するようにする。
※ただし、プログラム途中でGPUメモリ解放等はしないように注意が必要とのこと

ソースコード

以下のような指定をする。

    config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    sess = tf.Session(graph=net_graph, config=config)
ソースコード全体のイメージ ※イメージです
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import time
import copy
import cv2 as cv
import tensorflow as tf
import numpy as np


def session_run(sess, inp):
    out = sess.run([
        sess.graph.get_tensor_by_name('num_detections:0'),
        sess.graph.get_tensor_by_name('detection_scores:0'),
        sess.graph.get_tensor_by_name('detection_boxes:0'),
        sess.graph.get_tensor_by_name('detection_classes:0')
    ],
    feed_dict={
        'image_tensor:0':
        inp.reshape(1, inp.shape[0], inp.shape[1], 3)
    })
    return out


def main():
    print("Hand Detection Start...\n")

    # カメラ準備 ##############################################################
    cap = cv.VideoCapture(0)
    cap.set(cv.CAP_PROP_FRAME_WIDTH, 1280)
    cap.set(cv.CAP_PROP_FRAME_HEIGHT, 720)

    # GPUメモリを必要な分だけ確保するよう設定
    config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))

    # 手検出モデルロード ######################################################
    with tf.Graph().as_default() as net1_graph:
        graph_data = tf.gfile.FastGFile('frozen_inference_graph1.pb',
                                        'rb').read()
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(graph_data)
        tf.import_graph_def(graph_def, name='')

    sess1 = tf.Session(graph=net1_graph, config=config)
    sess1.graph.as_default()

    # 指先検出モデルロード ######################################################
    with tf.Graph().as_default() as net2_graph:
        graph_data = tf.gfile.FastGFile('frozen_inference_graph2.pb',
                                        'rb').read()
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(graph_data)
        tf.import_graph_def(graph_def, name='')

    sess2 = tf.Session(graph=net2_graph, config=config)
    sess2.graph.as_default()

    while True:
        start_time = time.time()

        # カメラキャプチャ ####################################################
        ret, frame = cap.read()
        if not ret:
            continue
        debug_image = copy.deepcopy(frame)

        # 手検出実施 ##########################################################
        inp = cv.resize(frame, (512, 512))
        inp = inp[:, :, [2, 1, 0]]  # BGR2RGB

        out = session_run(sess1, inp)

        rows = frame.shape[0]
        cols = frame.shape[1]

        num_detections = int(out[0][0])
        for i in range(num_detections):
            class_id = int(out[3][0][i])
            score = float(out[1][0][i])
            bbox = [float(v) for v in out[2][0][i]]

            if score < 0.8:
                continue

            x = int(bbox[1] * cols)
            y = int(bbox[0] * rows)
            right = int(bbox[3] * cols)
            bottom = int(bbox[2] * rows)

            # 指先検出実施 ####################################################
            if class_id == 3:
                trimming_image = debug_image[y:bottom, x:right]
                inp = cv.resize(trimming_image, (300, 300))
                inp = inp[:, :, [2, 1, 0]]  # BGR2RGB

                f_rows = trimming_image.shape[0]
                f_cols = trimming_image.shape[1]

                out = session_run(sess2, inp)

                f_num_detections = int(out[0][0])
                for i in range(f_num_detections):
                    f_score = float(out[1][0][i])
                    f_bbox = [float(v) for v in out[2][0][i]]

                    f_x = int(f_bbox[1] * f_cols)
                    f_y = int(f_bbox[0] * f_rows)
                    f_right = int(f_bbox[3] * f_cols)
                    f_bottom = int(f_bbox[2] * f_rows)

                    if f_score < 0.4:
                        continue

                    # 指検出結果可視化 ########################################
                    cv.rectangle(
                        trimming_image, (f_x, f_y), (f_right, f_bottom),
                        (0, 255, 0),
                        thickness=2)

            # 手検出結果可視化 ################################################
            cv.rectangle(
                debug_image, (x, y), (right, bottom), (0, 255, 0), thickness=2)

        # 処理時間描画 ########################################################
        elapsed_time = time.time() - start_time
        time_string = u"elapsed time:" + '{:.3g}'.format(elapsed_time)
        cv.putText(debug_image, time_string, (10, 50), cv.FONT_HERSHEY_COMPLEX,
                   1.0, (0, 255, 0))

        # 画面反映 ############################################################
        cv.imshow(' ', debug_image)
        cv.moveWindow(' ', 100, 100)

        key = cv.waitKey(1)
        if key == 27:  # ESC
            break


if __name__ == '__main__':
    main()

以上。

11
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?