64
35

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorchのDataLoaderが遅い

Last updated at Posted at 2020-03-13

PyTorchでは、データセットからミニバッチを取り出すのにDataLoader(torch.utils.data.DataLoader)がよく用いられるが、大きなサイズのデータを用いて実験しているときに、PyTorchのDataLoaderを用いると、とても時間がかかることがわかった。比較のためにデータセットからミニバッチを取り出すイタレータを自作し試してみたが、それと比べてもPytorchのDataLoaderはかなり遅いことがわかった。特にこれは、大きなサイズのデータを用いている時にはボトルネックになりうると思われる。

[追記:2020/03/23] 遅い原因はDataLoaderにデフォルトで使われているBatchSamplerであるというコメントをいただきました。詳しくはコメントをご覧ください。

設定

以下では、データ数が100万のlabeltargetからバッチサイズ1万のミニバッチを繰り返し取り出すことを想定する。計算環境はGoogle Colaboratoryを使用した。

import torch

label  = torch.randn(1000000,10)
target = torch.randn(1000000,10)
batch_size = 10000

ミニバッチを取り出すためのloaderを作成し、単にミニバッチを取り出すことを繰り返す以下の関数を用いて実行時間の計測を行う。

def run_loader(loader):
    for label,target in loader:
        pass

PytorchのDataLoader

torch.utils.data.DataLoaderを用いてloaderを作成し(shuffleしない場合)、実行時間を計測すると、6.8秒であった。単にデータを取り出しているだけにしては時間がかかっている感じがする。

dataset = torch.utils.data.TensorDataset(label,target)
loader1 = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=False)

%timeit -n1 -r1 run_loader(loader1)

# 1 loop, best of 1: 6.83 s per loop

shuffleを行う場合は、7.0秒であった。

dataset = torch.utils.data.TensorDataset(label,target)
loader2 = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=True)

%timeit -n1 -r1 run_loader(loader2)

# 1 loop, best of 1: 6.97 s per loop

自作のDataLoader

比較のために、データセットからミニバッチを取り出すイタレータを作成し同様に実験を行なってみた。

class DataLoader:

    def __init__(self,dataset,batch_size=1,shuffle=False):
        self.dataset = dataset 
        self.batch_size = batch_size
        self.shuffle = shuffle
        assert all([ dataset[i].size(0) == dataset[0].size(0) for i in range(len(dataset)) ]), 'all the elemtnes must have the same length'
        self.data_size = dataset[0].size(0)

    def __iter__(self):
        self._i = 0
        
        if self.shuffle:
            index_shuffle = torch.randperm(self.data_size)
            self.dataset = [ v[index_shuffle] for v in self.dataset ]

        return self

    def __next__(self):

        i1 = self.batch_size * self._i
        i2 = min( self.batch_size * ( self._i + 1 ), self.data_size )
        
        if i1 >= self.data_size:
            raise StopIteration()

        value = [ v[i1:i2] for v in self.dataset ]

        self._i += 1

        return value

自作のDataLoaderを用いた場合には(shuffleはしない)、実行時間が500マイクロ秒とほとんど取り出しには時間がかからないことがわかる。

loader3 = DataLoader([label,target],batch_size=batch_size,shuffle=False)

%timeit -n1 -r1 run_loader(loader3)

# 1 loop, best of 1: 468 µs per loop

shuffleする場合には実行時間は300ミリ秒であり、しない場合と比べると時間はかかるが、それでもPytorchのDataLoaderを用いた場合に比べると、無視できるほどの時間である。

loader4 = DataLoader([label,target],batch_size=batch_size,shuffle=True)

%timeit -n1 -r1 run_loader(loader4)

# 1 loop, best of 1: 296 ms per loop

まとめ

PyTorchのDataLoaderを用いてミニバッチを取り出す時にとても時間がかかることがわかった。特に大きなサイズのデータを扱う時にはこの影響はとても大きい。

64
35
3

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
64
35

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?