24
Help us understand the problem. What are the problem?

More than 3 years have passed since last update.

posted at

updated at

Pytorchにおいて自作Dataset内で乱数を使うときの注意点

Pytorchで少し凝った入出力のNNを作成するときには、既存のDatasetで対応できないことがあります。その際にはtorch.utils.data.datasetを継承する形で自作のDatasetを作成するのですが、そこで乱数を使っていると意図しない挙動をするケースがあったので、書き残しておきます。乱数と再現性は本当に難しいですね……。

環境

  • Python 3.6.3
  • Pytorch 0.4.1
  • macOS High Sierra 10.13.6

問題

ここでは例として、以下のようにnp.random.uniform(-1, 1)の値を返すDatasetを考えます。想定としては、このDatasetから値を取り出すたびに、-1から1の間の値がランダムに返ってきてほしいわけです。

import torch
from torch.utils.data.dataset import Dataset

class RandomDataset(Dataset):
    def __init__(self):
        pass

    def __getitem__(self, index):
        return np.random.uniform(-1, 1)

    def __len__(self):
        return 10

ではこれを単純にDataLoaderで取り出すとどうなるでしょうか。

dataset = RandomDataset()
dataloader = torch.utils.data.DataLoader(dataset,
                                         shuffle=True,
                                         num_workers=4,
                                         batch_size=1)
In [ ]: for i, n in enumerate(dataloader):
   ...:     print(i, n)

0 tensor([-0.4910], dtype=torch.float64)
1 tensor([-0.4910], dtype=torch.float64)
2 tensor([-0.4910], dtype=torch.float64)
3 tensor([-0.4910], dtype=torch.float64)
4 tensor([-0.5297], dtype=torch.float64)
5 tensor([-0.5297], dtype=torch.float64)
6 tensor([-0.5297], dtype=torch.float64)
7 tensor([-0.5297], dtype=torch.float64)
8 tensor([-0.2516], dtype=torch.float64)
9 tensor([-0.2516], dtype=torch.float64)

毎回ランダムにDatasetから取得していたはずが、どうも一定の回数だけ同じ値が帰ってきています。よく見るとnum_workersの数だけ重複しています。

DataLoaderの仕様

この問題は、公式ドキュメントでも触れられています。

これによると、workerの分だけforkするときに同一のnumpyのRandom Stateの状態がコピーされるようです。そのため、Datasetの中でランダムな処理を書いていたとしても、それぞれのworkderで同じ結果が返されていたわけです。

対応方法

Dataset側を変えずに対応するにはDataLoaderworker_init_fnパラメータでworkerの初期化時にシードを指定します。一番単純な方法としては、np.random.seed()をworkerごとに実行する方法です。

np.random generates the same random numbers for each data batch · Issue #5059 · pytorch/pytorch

dataset = RandomDataset()
dataloader = torch.utils.data.DataLoader(dataset,
                                         shuffle=True,
                                         num_workers=4,
                                         batch_size=1,
                                         worker_init_fn=lambda x: np.random.seed())

これで、Datasetから取り出すたびに違う値を得ることができました。ただし、これは完全にランダムになってしまうので、DataLoaderから値を取り出す操作をするたびに異なる結果になります。

In [ ]: for i, n in enumerate(dataloader):
   ...:     print(i, n)

0 tensor([-0.0622], dtype=torch.float64)
1 tensor([0.2198], dtype=torch.float64)
2 tensor([-0.5684], dtype=torch.float64)
3 tensor([0.7973], dtype=torch.float64)
4 tensor([-0.5899], dtype=torch.float64)
5 tensor([-0.6546], dtype=torch.float64)
6 tensor([0.0219], dtype=torch.float64)
7 tensor([-0.0046], dtype=torch.float64)
8 tensor([-0.0272], dtype=torch.float64)
9 tensor([-0.8230], dtype=torch.float64)

またはworker_idを使ってシードをばらけさせつつ再現性を担保するやり方もあります。この方法だと、スクリプトの最初でシードを固定しておけば、workerごとに異なるランダム出力を得つつ実行結果が毎回同じようなDataLoaderが作成できます。

def worker_init_fn(worker_id):                                                          
    np.random.seed(np.random.get_state()[1][0] + worker_id)

np.random.seed(42)
dataset = RandomDataset()
dataloader = torch.utils.data.DataLoader(dataset,
                                         shuffle=True,
                                         num_workers=4,
                                         batch_size=1,
                                         worker_init_fn=worker_init_fn)    

参考

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Sign upLogin
24
Help us understand the problem. What are the problem?