DataLoaderの挙動
Windows環境下でPyTorchを用いた機械学習を実行していると
num_workers > 0
のケースにおいて iteration 自体は速度が出ていても
DataLoader の epoch開始 が非常に遅いことがある.
この原因としては主に Windows のプロセスの生成が遅いことにあると考えられ、
PyTorch の GitHub をみると同様の問題の Issue も挙げられているようである.
しかしながら WindowsでDeep Learning をやるというマイナー勢力のために
貴重な労力を割かれるということもなく、根本的な対策をとることは考えられていない様子.
ボトルネックは子プロセスのインスタンスを破棄・再生成することにあるとすると
インスタンスを使いまわすようにすれば 生成のオーバーヘッドを回避できると考えられる.
この挙動を制御するためにpersistent_workers
というオプションが存在している.
persistent_workers (bool, optional) – If True, the data loader will not shutdown the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive. (default: False)
デフォルトではFalse
これをTrue
指定することで生成したインスタンスを破棄せずに使いまわすようになる.
欠点としては子プロセスを生かしているため使用したメモリも保持したままとなる.
ただ、通常の学習中においては元々消費する想定で動かしているので
学習の際に特殊な処理がない限りはあまり影響もないようにも思える.
※推論サーバーなどとして長時間起動して利用するケースには注意が必要かもしれない
Code Sample
変更点は以下のとおりオプションを追加するにするだけである.
import torch
# train_df は用いる学習データとする.
train_loader = torch.utils.data.DataLoader(
CustomDataset(train_df),
batch_size=32,
shuffle=True,
drop_last=True,
num_workers=4,
+ persistent_workers=True,
)
Ubuntu環境などでは本指定が不要とすると
下記のようにしておくと今後のコード変更が少なく済むと思う.
import os
import torch
train_loader = torch.utils.data.DataLoader(
CustomDataset(train_df),
batch_size=32,
shuffle=True,
drop_last=True,
num_workers=4,
persistent_workers=(os.name == 'nt'),
)
WindowsでPyTorchを使っている方はどれほど居られるかわかりませんが...
なんらかのコードの助けになれば幸いです。