PyTorchでは、データセットからミニバッチを取り出すのにDataLoader(torch.utils.data.DataLoader
)がよく用いられるが、大きなサイズのデータを用いて実験しているときに、PyTorchのDataLoaderを用いると、とても時間がかかることがわかった。比較のためにデータセットからミニバッチを取り出すイタレータを自作し試してみたが、それと比べてもPytorchのDataLoaderはかなり遅いことがわかった。特にこれは、大きなサイズのデータを用いている時にはボトルネックになりうると思われる。
[追記:2020/03/23] 遅い原因はDataLoaderにデフォルトで使われているBatchSamplerであるというコメントをいただきました。詳しくはコメントをご覧ください。
設定
以下では、データ数が100万のlabel
とtarget
からバッチサイズ1万のミニバッチを繰り返し取り出すことを想定する。計算環境はGoogle Colaboratoryを使用した。
import torch
label = torch.randn(1000000,10)
target = torch.randn(1000000,10)
batch_size = 10000
ミニバッチを取り出すためのloader
を作成し、単にミニバッチを取り出すことを繰り返す以下の関数を用いて実行時間の計測を行う。
def run_loader(loader):
for label,target in loader:
pass
PytorchのDataLoader
torch.utils.data.DataLoader
を用いてloaderを作成し(shuffleしない場合)、実行時間を計測すると、6.8秒であった。単にデータを取り出しているだけにしては時間がかかっている感じがする。
dataset = torch.utils.data.TensorDataset(label,target)
loader1 = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=False)
%timeit -n1 -r1 run_loader(loader1)
# 1 loop, best of 1: 6.83 s per loop
shuffleを行う場合は、7.0秒であった。
dataset = torch.utils.data.TensorDataset(label,target)
loader2 = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=True)
%timeit -n1 -r1 run_loader(loader2)
# 1 loop, best of 1: 6.97 s per loop
自作のDataLoader
比較のために、データセットからミニバッチを取り出すイタレータを作成し同様に実験を行なってみた。
class DataLoader:
def __init__(self,dataset,batch_size=1,shuffle=False):
self.dataset = dataset
self.batch_size = batch_size
self.shuffle = shuffle
assert all([ dataset[i].size(0) == dataset[0].size(0) for i in range(len(dataset)) ]), 'all the elemtnes must have the same length'
self.data_size = dataset[0].size(0)
def __iter__(self):
self._i = 0
if self.shuffle:
index_shuffle = torch.randperm(self.data_size)
self.dataset = [ v[index_shuffle] for v in self.dataset ]
return self
def __next__(self):
i1 = self.batch_size * self._i
i2 = min( self.batch_size * ( self._i + 1 ), self.data_size )
if i1 >= self.data_size:
raise StopIteration()
value = [ v[i1:i2] for v in self.dataset ]
self._i += 1
return value
自作のDataLoaderを用いた場合には(shuffleはしない)、実行時間が500マイクロ秒とほとんど取り出しには時間がかからないことがわかる。
loader3 = DataLoader([label,target],batch_size=batch_size,shuffle=False)
%timeit -n1 -r1 run_loader(loader3)
# 1 loop, best of 1: 468 µs per loop
shuffleする場合には実行時間は300ミリ秒であり、しない場合と比べると時間はかかるが、それでもPytorchのDataLoaderを用いた場合に比べると、無視できるほどの時間である。
loader4 = DataLoader([label,target],batch_size=batch_size,shuffle=True)
%timeit -n1 -r1 run_loader(loader4)
# 1 loop, best of 1: 296 ms per loop
まとめ
PyTorchのDataLoaderを用いてミニバッチを取り出す時にとても時間がかかることがわかった。特に大きなサイズのデータを扱う時にはこの影響はとても大きい。