内容
datasetからミニバッチを作成する際によく使用するDataLoader
。
しかし、記述次第では毎回バラバラなバッチを取り出してしまうことに。。。
そこで、本記事ではこの 取り出し方(乱数) を固定する方法を記述したい。
環境
- Python 3.10.12
- torch 2.0.1
実装
seed = 42
def seed_worker(worker_id):
worker_seed = torch.initial_seed()
# Seed other libraries with torch's seed
random.seed(worker_seed)
# Numpy seed must be between 0 and 2**32 - 1
if worker_seed >= 2 ** 32:
worker_seed = worker_seed % 2 ** 32
np.random.seed(worker_seed)
train_loader = DataLoader(
train_dataset,
batch_size=100,
shuffle=True,
num_workers=2,
worker_init_fn=seed_worker, # <--- 追加
generator=torch.Generator().manual_seed(seed), # <--- 追加
)
基本的に、シード値を設定した上で、worker_init_fn
とgenerator
の2つの引数を記述してもらうことで乱数を固定することができる。