画像処理でGPUを使ったライブラリを動かしたとき、一度に全部データを渡すと落ちちゃう場合がありました。
こんな場合には入力を分割して処理を回しますが、なるべく他の処理を増やしたくありません。
こういった場合の対処方法をメモします。
入出力numpyの関数gpu_process(imgs)を使ってgpu上で画像を処理し、その結果が返されるものとします。
print(imgs.shape) # (NUM_SAMPLE,WIDTH,HEIGHT,3)
outputs = gpu_process(imgs) # ここで落ちる
print(outputs.shape) # (NUM_SAMPLE,WIDTH,HEIGHT,3)
これを下のようにnp.array_split()で分割して渡し、得られたarrayのリストをnp.concatenate()してnumpyに戻します。
print(imgs.shape) # (NUM_SAMPLE,WIDTH,HEIGHT,3)
NUM_SPLIT = 10 # ←増やすほどGPUの処理を軽く出来る
imgs_list = np.array_split(imgs, NUM_SPLIT)
outputs = np.concatenate(
[gpu_process(imgs) for imgs in imgs_list],
axis=0,
)
print(outputs.shape) # (NUM_SAMPLE,WIDTH,HEIGHT,3)