3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PytorchのDataLoaderのshuffleについて

Last updated at Posted at 2021-10-22

DataLoaderについて

PytorchのDataLoaderはよく機械学習モデルで使用されています。これを用いることで,エポック毎に異なる組み合わせや順番で ミニバッチ学習を行うことができます。しかし,それには再現性があるのでしょうか。今まであまり確認していなかったので確認してみることにします。

(また誤りや質問などありましたらコメントお願いいたします。
追記 : GPUなどを用いた実際の深層学習においては,
こちらが参考になると思います。
north_rewing様の記事『Pytorchの再現性に関して』
makaishi2様の記事『PyTorch CNNモデル再現性問題』
を参照してみてください。)

深層学習における再現性(seed)の重要性に関して

Takahiro Kuboさんのこちらのスライドには強化学習においてseedが異なるだけで汎化性能が大きく異なる論文を紹介されており, 汎化性能を担保したいモデルを運用したい場合はこの辺りの話は極めて重要となります。

結論

学習時のエポックのfor文が始まる前に

seed=42
torch.manual_seed(seed)
# for GPU
# torch.cuda.manual_seed(seed)

といった文言をつけましょう。

実験

import numpy as np
import torch

#サンプル
x = np.array([[1], [2], [3], [4], [5], [6]])
y = np.array([[10], [20], [30], [40], [50], [60]])
x = torch.tensor(x)
y = torch.tensor(y)
dataset = torch.utils.data.TensorDataset(x, y)
batch_size=3


data_loader1 = torch.utils.data.DataLoader(dataset, 
                       batch_size=batch_size, shuffle=True)
data_loader2 = torch.utils.data.DataLoader(dataset, 
                       batch_size=batch_size, shuffle=True)

print(f"--- dataloader{1}")
for _ in range(2):
    print(f"{_}番目のエポック")
    for i,(x,y) in enumerate(data_loader1):
        print(f"  {i+1}番目のバッチ  :  x:{x.tolist()}, y:{y.tolist()}")
print(f"--- dataloader{2}")
for _ in range(2):
    print(f"{_}番目のエポック")
    for i,(x,y) in enumerate(data_loader2):
        print(f"  {i+1}番目のバッチ  :  x:{x.tolist()}, y:{y.tolist()}")
--- dataloader1
0番目のエポック
  1番目のバッチ  :  x:[[3], [2], [1]], y:[[30], [20], [10]]
  2番目のバッチ  :  x:[[5], [4], [6]], y:[[50], [40], [60]]
1番目のエポック
  1番目のバッチ  :  x:[[1], [2], [5]], y:[[10], [20], [50]]
  2番目のバッチ  :  x:[[4], [3], [6]], y:[[40], [30], [60]]
--- dataloader2
0番目のエポック
  1番目のバッチ  :  x:[[6], [3], [2]], y:[[60], [30], [20]]
  2番目のバッチ  :  x:[[5], [1], [4]], y:[[50], [10], [40]]
1番目のエポック
  1番目のバッチ  :  x:[[1], [6], [4]], y:[[10], [60], [40]]
  2番目のバッチ  :  x:[[3], [2], [5]], y:[[30], [20], [50]]

dataloader1とdataloade2の結果が異なるため,再現性がありません...。shuffle=Falseではもちろん再現性がありますが,念の為,確かめてみましょう。


data_loader3 = torch.utils.data.DataLoader(dataset, 
                       batch_size=batch_size, shuffle=False)
data_loader4 = torch.utils.data.DataLoader(dataset, 
                       batch_size=batch_size, shuffle=False)
print(f"--- dataloader{3}")
for _ in range(2):
    print(f"{_}番目のエポック")
    for i,(x,y) in enumerate(data_loader3):
        print(f"  {i+1}番目のバッチ  :  x:{x.tolist()}, y:{y.tolist()}")
print(f"--- dataloader{4}")
for _ in range(2):
    print(f"{_}番目のエポック")
    for i,(x,y) in enumerate(data_loader4):
        print(f"  {i+1}番目のバッチ  :  x:{x.tolist()}, y:{y.tolist()}")
--- dataloader3
0番目のエポック
  1番目のバッチ  :  x:[[1], [2], [3]], y:[[10], [20], [30]]
  2番目のバッチ  :  x:[[4], [5], [6]], y:[[40], [50], [60]]
1番目のエポック
  1番目のバッチ  :  x:[[1], [2], [3]], y:[[10], [20], [30]]
  2番目のバッチ  :  x:[[4], [5], [6]], y:[[40], [50], [60]]
--- dataloader4
0番目のエポック
  1番目のバッチ  :  x:[[1], [2], [3]], y:[[10], [20], [30]]
  2番目のバッチ  :  x:[[4], [5], [6]], y:[[40], [50], [60]]
1番目のエポック
  1番目のバッチ  :  x:[[1], [2], [3]], y:[[10], [20], [30]]
  2番目のバッチ  :  x:[[4], [5], [6]], y:[[40], [50], [60]]

再現性ありました。
ではsheffle=Trueにしつつ再現性を保つ i.e.
dataloader1とdataloader2の結果を同じにするにはどのようにすれば良いのでしょうか。

試み1

dataloaderが定義されるときにseedを固定する.

torch.manual_seed(42)
data_loader_new1 = torch.utils.data.DataLoader(dataset, 
                       batch_size=batch_size, shuffle=True)
torch.manual_seed(42)
data_loader_new2 = torch.utils.data.DataLoader(dataset, 
                       batch_size=batch_size, shuffle=True)

print(f"--- dataloader{1}")
for _ in range(2):
    print(f"{_}番目のエポック")
    for i,(x,y) in enumerate(data_loader_new1):
        print(f"  {i+1}番目のバッチ  :  x:{x.tolist()}, y:{y.tolist()}")
print(f"--- dataloader{2}")
for _ in range(2):
    print(f"{_}番目のエポック")
    for i,(x,y) in enumerate(data_loader_new2):
        print(f"  {i+1}番目のバッチ  :  x:{x.tolist()}, y:{y.tolist()}")
--- dataloader1
0番目のエポック
  1番目のバッチ  :  x:[[5], [6], [2]], y:[[50], [60], [20]]
  2番目のバッチ  :  x:[[1], [4], [3]], y:[[10], [40], [30]]
1番目のエポック
  1番目のバッチ  :  x:[[5], [6], [2]], y:[[50], [60], [20]]
  2番目のバッチ  :  x:[[3], [4], [1]], y:[[30], [40], [10]]
--- dataloader2
0番目のエポック
  1番目のバッチ  :  x:[[3], [2], [1]], y:[[30], [20], [10]]
  2番目のバッチ  :  x:[[5], [4], [6]], y:[[50], [40], [60]]
1番目のエポック
  1番目のバッチ  :  x:[[1], [2], [5]], y:[[10], [20], [50]]
  2番目のバッチ  :  x:[[4], [3], [6]], y:[[40], [30], [60]]

だめでした。

試み.2

dataloaderから取り出す際にseedを固定する.

print(f"--- dataloader{1}")
torch.manual_seed(42)
for _ in range(2):
    print(f"{_}番目のエポック")
    for i,(x,y) in enumerate(data_loader1):
        print(f"  {i+1}番目のバッチ  :  x:{x.tolist()}, y:{y.tolist()}")
print(f"--- dataloader{2}")
torch.manual_seed(42)
for _ in range(2):
    print(f"{_}番目のエポック")
    for i,(x,y) in enumerate(data_loader2):
        print(f"  {i+1}番目のバッチ  :  x:{x.tolist()}, y:{y.tolist()}")
--- dataloader1
0番目のエポック
  1番目のバッチ  :  x:[[5], [6], [2]], y:[[50], [60], [20]]
  2番目のバッチ  :  x:[[1], [4], [3]], y:[[10], [40], [30]]
1番目のエポック
  1番目のバッチ  :  x:[[5], [6], [2]], y:[[50], [60], [20]]
  2番目のバッチ  :  x:[[3], [4], [1]], y:[[30], [40], [10]]
--- dataloader2
0番目のエポック
  1番目のバッチ  :  x:[[5], [6], [2]], y:[[50], [60], [20]]
  2番目のバッチ  :  x:[[1], [4], [3]], y:[[10], [40], [30]]
1番目のエポック
  1番目のバッチ  :  x:[[5], [6], [2]], y:[[50], [60], [20]]
  2番目のバッチ  :  x:[[3], [4], [1]], y:[[30], [40], [10]]

このままでは,全てのエポックで同じ組み合わせを持ってきているため,shuffleしている感じを確認できません。現状何とも言えないため,エポック数を増やしてみましょう。

print(f"--- dataloader{1}")
torch.manual_seed(42)
for _ in range(5):
    print(f"{_}番目のエポック")
    for i,(x,y) in enumerate(data_loader1):
        print(f"  {i+1}番目のバッチ  :  x:{x.tolist()}, y:{y.tolist()}")
print(f"--- dataloader{2}")
torch.manual_seed(42)
for _ in range(5):
    print(f"{_}番目のエポック")
    for i,(x,y) in enumerate(data_loader2):
        print(f"  {i+1}番目のバッチ  :  x:{x.tolist()}, y:{y.tolist()}")
--- dataloader1
0番目のエポック
  1番目のバッチ  :  x:[[5], [6], [2]], y:[[50], [60], [20]]
  2番目のバッチ  :  x:[[1], [4], [3]], y:[[10], [40], [30]]
1番目のエポック
  1番目のバッチ  :  x:[[5], [6], [2]], y:[[50], [60], [20]]
  2番目のバッチ  :  x:[[3], [4], [1]], y:[[30], [40], [10]]
2番目のエポック
  1番目のバッチ  :  x:[[3], [2], [1]], y:[[30], [20], [10]]
  2番目のバッチ  :  x:[[5], [4], [6]], y:[[50], [40], [60]]
3番目のエポック
  1番目のバッチ  :  x:[[1], [2], [5]], y:[[10], [20], [50]]
  2番目のバッチ  :  x:[[4], [3], [6]], y:[[40], [30], [60]]
4番目のエポック
  1番目のバッチ  :  x:[[6], [3], [2]], y:[[60], [30], [20]]
  2番目のバッチ  :  x:[[5], [1], [4]], y:[[50], [10], [40]]
--- dataloader2
0番目のエポック
  1番目のバッチ  :  x:[[5], [6], [2]], y:[[50], [60], [20]]
  2番目のバッチ  :  x:[[1], [4], [3]], y:[[10], [40], [30]]
1番目のエポック
  1番目のバッチ  :  x:[[5], [6], [2]], y:[[50], [60], [20]]
  2番目のバッチ  :  x:[[3], [4], [1]], y:[[30], [40], [10]]
2番目のエポック
  1番目のバッチ  :  x:[[3], [2], [1]], y:[[30], [20], [10]]
  2番目のバッチ  :  x:[[5], [4], [6]], y:[[50], [40], [60]]
3番目のエポック
  1番目のバッチ  :  x:[[1], [2], [5]], y:[[10], [20], [50]]
  2番目のバッチ  :  x:[[4], [3], [6]], y:[[40], [30], [60]]
4番目のエポック
  1番目のバッチ  :  x:[[6], [3], [2]], y:[[60], [30], [20]]
  2番目のバッチ  :  x:[[5], [1], [4]], y:[[50], [10], [40]]

これで確かに,再現性が保てることがわかりました。

この辺りを含めdropoutや深層学習の初期値に関してもseedを気にする必要があります。
今後機会があればこの辺りもまとめてみたいと思います。

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?