Pytorchで少し凝った入出力のNNを作成するときには、既存のDatasetで対応できないことがあります。その際にはtorch.utils.data.dataset
を継承する形で自作のDatasetを作成するのですが、そこで乱数を使っていると意図しない挙動をするケースがあったので、書き残しておきます。乱数と再現性は本当に難しいですね……。
環境
- Python 3.6.3
- Pytorch 0.4.1
- macOS High Sierra 10.13.6
問題
ここでは例として、以下のようにnp.random.uniform(-1, 1)
の値を返すDatasetを考えます。想定としては、このDatasetから値を取り出すたびに、-1から1の間の値がランダムに返ってきてほしいわけです。
import torch
from torch.utils.data.dataset import Dataset
class RandomDataset(Dataset):
def __init__(self):
pass
def __getitem__(self, index):
return np.random.uniform(-1, 1)
def __len__(self):
return 10
ではこれを単純にDataLoaderで取り出すとどうなるでしょうか。
dataset = RandomDataset()
dataloader = torch.utils.data.DataLoader(dataset,
shuffle=True,
num_workers=4,
batch_size=1)
In [ ]: for i, n in enumerate(dataloader):
...: print(i, n)
0 tensor([-0.4910], dtype=torch.float64)
1 tensor([-0.4910], dtype=torch.float64)
2 tensor([-0.4910], dtype=torch.float64)
3 tensor([-0.4910], dtype=torch.float64)
4 tensor([-0.5297], dtype=torch.float64)
5 tensor([-0.5297], dtype=torch.float64)
6 tensor([-0.5297], dtype=torch.float64)
7 tensor([-0.5297], dtype=torch.float64)
8 tensor([-0.2516], dtype=torch.float64)
9 tensor([-0.2516], dtype=torch.float64)
毎回ランダムにDatasetから取得していたはずが、どうも一定の回数だけ同じ値が帰ってきています。よく見るとnum_workers
の数だけ重複しています。
DataLoaderの仕様
この問題は、公式ドキュメントでも触れられています。
- Frequently Asked Questions — PyTorch master documentation
- torch.utils.data — PyTorch master documentation
これによると、workerの分だけforkするときに同一のnumpyのRandom Stateの状態がコピーされるようです。そのため、Datasetの中でランダムな処理を書いていたとしても、それぞれのworkderで同じ結果が返されていたわけです。
対応方法
Dataset側を変えずに対応するにはDataLoader
のworker_init_fn
パラメータでworkerの初期化時にシードを指定します。一番単純な方法としては、np.random.seed()
をworkerごとに実行する方法です。
np.random generates the same random numbers for each data batch · Issue #5059 · pytorch/pytorch
dataset = RandomDataset()
dataloader = torch.utils.data.DataLoader(dataset,
shuffle=True,
num_workers=4,
batch_size=1,
worker_init_fn=lambda x: np.random.seed())
これで、Datasetから取り出すたびに違う値を得ることができました。ただし、これは完全にランダムになってしまうので、DataLoaderから値を取り出す操作をするたびに異なる結果になります。
In [ ]: for i, n in enumerate(dataloader):
...: print(i, n)
0 tensor([-0.0622], dtype=torch.float64)
1 tensor([0.2198], dtype=torch.float64)
2 tensor([-0.5684], dtype=torch.float64)
3 tensor([0.7973], dtype=torch.float64)
4 tensor([-0.5899], dtype=torch.float64)
5 tensor([-0.6546], dtype=torch.float64)
6 tensor([0.0219], dtype=torch.float64)
7 tensor([-0.0046], dtype=torch.float64)
8 tensor([-0.0272], dtype=torch.float64)
9 tensor([-0.8230], dtype=torch.float64)
またはworker_idを使ってシードをばらけさせつつ再現性を担保するやり方もあります。この方法だと、スクリプトの最初でシードを固定しておけば、workerごとに異なるランダム出力を得つつ実行結果が毎回同じようなDataLoaderが作成できます。
def worker_init_fn(worker_id):
np.random.seed(np.random.get_state()[1][0] + worker_id)
np.random.seed(42)
dataset = RandomDataset()
dataloader = torch.utils.data.DataLoader(dataset,
shuffle=True,
num_workers=4,
batch_size=1,
worker_init_fn=worker_init_fn)
参考
-
pytorchでCNNのlossが毎回変わる問題の対処法 (on cpu) - Qiita
- 再現性のために複数workerでシードを固定する際に、同様に
worker_init_fn
を使っています
- 再現性のために複数workerでシードを固定する際に、同様に
- Does __getitem__ of dataloader reset random seed? - PyTorch Forums
-
utkuozbulak/pytorch-custom-dataset-examples: Some custom dataset examples for PyTorch
- custom datasetの作り方