はじめに
以下のように複数のモデル(手検出モデルと指先検出モデル)を1つのプログラムで実行しようとした時にGPUメモリで躓いたためメモ。
以下のプログラムでは手検出で1つ、指先検出で1つの計2つのモデルを利用しています。
複数モデル(手検出、指検出)を実行するプログラムを少々書き直し中。。。🐤 pic.twitter.com/tx1fH3Ygkr
— 高橋かずひと@リベロ拝命🐤 (@KzhtTkhs) August 20, 2019
躓いたポイント
1つ目のモデルを読み込んだ後、2つ目のモデルを動作させようとするとGPUメモリ不足でプログラムが異常終了する。
原因
TensorFlowはデフォルト設定では、GPUのメモリのほぼ全てを使用して、メモリの断片化を軽減するようにしているようです。
1つ目のモデルの時点でメモリをほぼ全て使い切ったため、2つ目のモデルは動かせなかった模様。
対処
Allowing GPU memory growthオプションを利用し、必要なメモリだけ確保するようにする。
※ただし、プログラム途中でGPUメモリ解放等はしないように注意が必要とのこと
ソースコード
以下のような指定をする。
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
sess = tf.Session(graph=net_graph, config=config)
ソースコード全体のイメージ ※イメージです
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import time
import copy
import cv2 as cv
import tensorflow as tf
import numpy as np
def session_run(sess, inp):
out = sess.run([
sess.graph.get_tensor_by_name('num_detections:0'),
sess.graph.get_tensor_by_name('detection_scores:0'),
sess.graph.get_tensor_by_name('detection_boxes:0'),
sess.graph.get_tensor_by_name('detection_classes:0')
],
feed_dict={
'image_tensor:0':
inp.reshape(1, inp.shape[0], inp.shape[1], 3)
})
return out
def main():
print("Hand Detection Start...\n")
# カメラ準備 ##############################################################
cap = cv.VideoCapture(0)
cap.set(cv.CAP_PROP_FRAME_WIDTH, 1280)
cap.set(cv.CAP_PROP_FRAME_HEIGHT, 720)
# GPUメモリを必要な分だけ確保するよう設定
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
# 手検出モデルロード ######################################################
with tf.Graph().as_default() as net1_graph:
graph_data = tf.gfile.FastGFile('frozen_inference_graph1.pb',
'rb').read()
graph_def = tf.GraphDef()
graph_def.ParseFromString(graph_data)
tf.import_graph_def(graph_def, name='')
sess1 = tf.Session(graph=net1_graph, config=config)
sess1.graph.as_default()
# 指先検出モデルロード ######################################################
with tf.Graph().as_default() as net2_graph:
graph_data = tf.gfile.FastGFile('frozen_inference_graph2.pb',
'rb').read()
graph_def = tf.GraphDef()
graph_def.ParseFromString(graph_data)
tf.import_graph_def(graph_def, name='')
sess2 = tf.Session(graph=net2_graph, config=config)
sess2.graph.as_default()
while True:
start_time = time.time()
# カメラキャプチャ ####################################################
ret, frame = cap.read()
if not ret:
continue
debug_image = copy.deepcopy(frame)
# 手検出実施 ##########################################################
inp = cv.resize(frame, (512, 512))
inp = inp[:, :, [2, 1, 0]] # BGR2RGB
out = session_run(sess1, inp)
rows = frame.shape[0]
cols = frame.shape[1]
num_detections = int(out[0][0])
for i in range(num_detections):
class_id = int(out[3][0][i])
score = float(out[1][0][i])
bbox = [float(v) for v in out[2][0][i]]
if score < 0.8:
continue
x = int(bbox[1] * cols)
y = int(bbox[0] * rows)
right = int(bbox[3] * cols)
bottom = int(bbox[2] * rows)
# 指先検出実施 ####################################################
if class_id == 3:
trimming_image = debug_image[y:bottom, x:right]
inp = cv.resize(trimming_image, (300, 300))
inp = inp[:, :, [2, 1, 0]] # BGR2RGB
f_rows = trimming_image.shape[0]
f_cols = trimming_image.shape[1]
out = session_run(sess2, inp)
f_num_detections = int(out[0][0])
for i in range(f_num_detections):
f_score = float(out[1][0][i])
f_bbox = [float(v) for v in out[2][0][i]]
f_x = int(f_bbox[1] * f_cols)
f_y = int(f_bbox[0] * f_rows)
f_right = int(f_bbox[3] * f_cols)
f_bottom = int(f_bbox[2] * f_rows)
if f_score < 0.4:
continue
# 指検出結果可視化 ########################################
cv.rectangle(
trimming_image, (f_x, f_y), (f_right, f_bottom),
(0, 255, 0),
thickness=2)
# 手検出結果可視化 ################################################
cv.rectangle(
debug_image, (x, y), (right, bottom), (0, 255, 0), thickness=2)
# 処理時間描画 ########################################################
elapsed_time = time.time() - start_time
time_string = u"elapsed time:" + '{:.3g}'.format(elapsed_time)
cv.putText(debug_image, time_string, (10, 50), cv.FONT_HERSHEY_COMPLEX,
1.0, (0, 255, 0))
# 画面反映 ############################################################
cv.imshow(' ', debug_image)
cv.moveWindow(' ', 100, 100)
key = cv.waitKey(1)
if key == 27: # ESC
break
if __name__ == '__main__':
main()
以上。