1. Qiita
  2. Items
  3. 機械学習

学習時のバッチデータのio待ちをマルチプロセスにして簡単に早くする

  • 11
    Like
  • 0
    Comment

深層学習を触っていると毎batchのdataを読み込むのio処理を待っている間、gpuを遊ばせていることがたまにあります。
そんなときに、以下のようなコードでマルチプロセス化することで、trainingとio処理を並行して行うことができるので早くすることができます。

loaderの返すデータはpickle化できる必要がありますが、numpyは標準でpickle可能なのでprocess間でやり取りできます。

この方式は簡単ですが、 multiprocessing.Queue().get() にすこしオーバーヘッドがあります。

code

import multiprocessing as mp


class PrefetchQueue(object):
    def __init__(self, loader, path_list, prefetch_num=2):
        self._queue = mp.Queue(prefetch_num)
        self._path_list = path_list
        self._loader = loader
        self._process = mp.Process(target=self._worker)
        self._process.start()

    def _worker(self):
        for path in self._path_list:
            data = self._loader(path)
            self._queue.put(data)
        self._queue.put(None)

    def __iter__(self):
        return iter(self._queue.get, None)

## sample code

import numpy as np
import time

def slow_loader(filename):
  print('start loading', filename)
  time.sleep(3)
  x = np.random.random((10,5))
  y = np.arange(10)
  print('end loading', filename)
  return x,y,filename

filenames = [
  'youjo1',
  'youjo2',
  'youjo3',
  'youjo4',
  'youjo5',
  'youjo6',
]

def main():
  for train in PrefetchQueue(slow_loader, filenames):
    x,y,filename = train
    print('start training', filename)
    print(x,y,filename)
    time.sleep(10) # training time is longer than loading file
    print('end training', filename)

if __name__ == '__main__':
  main()