0
0

【Pytorch】DataLoaderの乱数を固定する方法(備忘録)

Posted at

内容

datasetからミニバッチを作成する際によく使用するDataLoader
しかし、記述次第では毎回バラバラなバッチを取り出してしまうことに。。。
そこで、本記事ではこの 取り出し方(乱数) を固定する方法を記述したい。

環境

  • Python 3.10.12
  • torch 2.0.1

実装

seed = 42

def seed_worker(worker_id):
    worker_seed = torch.initial_seed()

    # Seed other libraries with torch's seed
    random.seed(worker_seed)

    # Numpy seed must be between 0 and 2**32 - 1
    if worker_seed >= 2 ** 32:
        worker_seed = worker_seed % 2 ** 32
    np.random.seed(worker_seed)


train_loader = DataLoader(
                            train_dataset,
                            batch_size=100,
                            shuffle=True,
                            num_workers=2,
                            worker_init_fn=seed_worker, # <--- 追加
                            generator=torch.Generator().manual_seed(seed), # <--- 追加
                        )

基本的に、シード値を設定した上で、worker_init_fngeneratorの2つの引数を記述してもらうことで乱数を固定することができる。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0