PyTorchのDataLoaderは,マルチプロセスでデータロードが出来る仕組みが備わっている.Windowsでそれを利用しようとしたところ,こちらと同じようなエラーを吐いて動作しなかった.そこで色々調べて解決を行ったため,その方法をメモする.
- 動作環境
- OS: Windows10 Home(バージョン 2004)
- Python: 3.7.3
- PyTorch: 1.3.1
#DataLoaderのマルチプロセス
公式のDocsから引用すると
A
DataLoader
uses single-process data loading by default.
Within a Python process, the Global Interpreter Lock (GIL) prevents true fully parallelizing Python code across threads. To avoid blocking computation code with data loading, PyTorch provides an easy switch to perform multi-process data loading by simply setting the argument
num_workers
to a positive integer.
とのこと.非常にざっくり言うと,DataLoaderクラスのnum_workers
という変数の値を1以上にすれば,データの読み込みを並列化できる,ということである.
#BrokenPipeError
そこで,num_workers
を1以上の値に設定して動かしたところ,
BrokenPipeError: [Errno 32] Broken pipe
とエラーを吐いて動作しなかった.
PytorchのDatasetをDataLoaderで並列に読み込みたいときのエラー(Windows)を参考に,Dataset
の定義を別のファイルで行っても同様のエラーが発生した.
#解決法
こちらを参考にすると,どうやらWindowsでマルチプロセスを行う場合,if __name__ == "__main__"
内でマルチプロセスを行う関数を実行しなければならないとのこと.
修正前
from torch.utils.data import DataLoader
from dataloader import MyDataset #作成したdataset
def train():
dataset = MyDataset()
train_loader = DataLoader(dataset, num_workers=2, shuffle=True,
batch_size=4,
pin_memory=True,
drop_last=True)
for batch in train_loader:
#do some process...
if __name__ == "__main__":
train()
修正後
from torch.utils.data import DataLoader
from dataloader import MyDataset #作成したdataset
def train(train_loader):
for batch in train_loader:
#do some process...
if __name__ == "__main__":
#dataset, DataLoaderを移動
dataset = MyDataset()
train_loader = DataLoader(dataset, num_workers=2, shuffle=True,
batch_size=4,
pin_memory=True,
drop_last=True)
train(train_loader)
DataLoaderの場合,インスタンスの生成をif __name__ == "__main__"
内で行っていれば,データの読み込み自体は別の関数内で実行してもマルチプロセスは動作した.
#まとめ
Windows環境でDataLoaderの並列化するためのメモを記した.深層学習周りでは,Windowsだと動作しないまたは少し工夫をしないと行けないといった作業が多く大変である.なので,Windows周りで発生するエラーは定期的に記事にしてまとめていきたい所存