41
27

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Pytorchにおいて自作Dataset内で乱数を使うときの注意点

Last updated at Posted at 2018-08-17

Pytorchで少し凝った入出力のNNを作成するときには、既存のDatasetで対応できないことがあります。その際にはtorch.utils.data.datasetを継承する形で自作のDatasetを作成するのですが、そこで乱数を使っていると意図しない挙動をするケースがあったので、書き残しておきます。乱数と再現性は本当に難しいですね……。

環境

  • Python 3.6.3
  • Pytorch 0.4.1
  • macOS High Sierra 10.13.6

問題

ここでは例として、以下のようにnp.random.uniform(-1, 1)の値を返すDatasetを考えます。想定としては、このDatasetから値を取り出すたびに、-1から1の間の値がランダムに返ってきてほしいわけです。

import torch
from torch.utils.data.dataset import Dataset

class RandomDataset(Dataset):
    def __init__(self):
        pass

    def __getitem__(self, index):
        return np.random.uniform(-1, 1)

    def __len__(self):
        return 10

ではこれを単純にDataLoaderで取り出すとどうなるでしょうか。

dataset = RandomDataset()
dataloader = torch.utils.data.DataLoader(dataset,
                                         shuffle=True,
                                         num_workers=4,
                                         batch_size=1)
In [ ]: for i, n in enumerate(dataloader):
   ...:     print(i, n)

0 tensor([-0.4910], dtype=torch.float64)
1 tensor([-0.4910], dtype=torch.float64)
2 tensor([-0.4910], dtype=torch.float64)
3 tensor([-0.4910], dtype=torch.float64)
4 tensor([-0.5297], dtype=torch.float64)
5 tensor([-0.5297], dtype=torch.float64)
6 tensor([-0.5297], dtype=torch.float64)
7 tensor([-0.5297], dtype=torch.float64)
8 tensor([-0.2516], dtype=torch.float64)
9 tensor([-0.2516], dtype=torch.float64)

毎回ランダムにDatasetから取得していたはずが、どうも一定の回数だけ同じ値が帰ってきています。よく見るとnum_workersの数だけ重複しています。

DataLoaderの仕様

この問題は、公式ドキュメントでも触れられています。

これによると、workerの分だけforkするときに同一のnumpyのRandom Stateの状態がコピーされるようです。そのため、Datasetの中でランダムな処理を書いていたとしても、それぞれのworkderで同じ結果が返されていたわけです。

対応方法

Dataset側を変えずに対応するにはDataLoaderworker_init_fnパラメータでworkerの初期化時にシードを指定します。一番単純な方法としては、np.random.seed()をworkerごとに実行する方法です。

np.random generates the same random numbers for each data batch · Issue #5059 · pytorch/pytorch

dataset = RandomDataset()
dataloader = torch.utils.data.DataLoader(dataset,
                                         shuffle=True,
                                         num_workers=4,
                                         batch_size=1,
                                         worker_init_fn=lambda x: np.random.seed())

これで、Datasetから取り出すたびに違う値を得ることができました。ただし、これは完全にランダムになってしまうので、DataLoaderから値を取り出す操作をするたびに異なる結果になります。

In [ ]: for i, n in enumerate(dataloader):
   ...:     print(i, n)

0 tensor([-0.0622], dtype=torch.float64)
1 tensor([0.2198], dtype=torch.float64)
2 tensor([-0.5684], dtype=torch.float64)
3 tensor([0.7973], dtype=torch.float64)
4 tensor([-0.5899], dtype=torch.float64)
5 tensor([-0.6546], dtype=torch.float64)
6 tensor([0.0219], dtype=torch.float64)
7 tensor([-0.0046], dtype=torch.float64)
8 tensor([-0.0272], dtype=torch.float64)
9 tensor([-0.8230], dtype=torch.float64)

またはworker_idを使ってシードをばらけさせつつ再現性を担保するやり方もあります。この方法だと、スクリプトの最初でシードを固定しておけば、workerごとに異なるランダム出力を得つつ実行結果が毎回同じようなDataLoaderが作成できます。

def worker_init_fn(worker_id):                                                          
    np.random.seed(np.random.get_state()[1][0] + worker_id)

np.random.seed(42)
dataset = RandomDataset()
dataloader = torch.utils.data.DataLoader(dataset,
                                         shuffle=True,
                                         num_workers=4,
                                         batch_size=1,
                                         worker_init_fn=worker_init_fn)    

参考

41
27
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
41
27

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?