LoginSignup
0
0

More than 1 year has passed since last update.

PyTorch の dataloader の shuffle 機能に再現性をもたせる方法

Last updated at Posted at 2022-07-02

はじめに

複数の実験条件で比較検証を行う場合、比較したい条件以外は同じ条件で実験しなければなりませんが、dataloader の shuffle 機能を使うと実行ごとに毎度変わったデータの読み込み順になってしまいます。これでは学習モデルの再現性を保つことができません。
今回は、データの読み込み順にランダム性をもたせつつ、shuffle機能の再現性をseedで担保する解決策を残しておきます。

解決方法

torch.manual_seed(seed)を加えます。

train.py
seed = 20220703

'''
追加コード: shuffle の再現性を担保する
'''
torch.manual_seed(seed) 

train_set = mydataset.MyDatasets(input_dir)
train_loader = torch.utils.data.DataLoader(train_set,batch_size=n_batch,shuffle=True,num_workers=4,pin_memory=True,drop_last=True)


for epoch in range(n_epoch):
    ...
    for itr, data_dict in enumerate(train_loader):
        ...

テストコード

実験を再現するためのコードも載せておきます。

make_dataset.py
'''
画像データを作成
'''

import os
import random
import cv2

outdir = './traindata'
os.makedirs(outdir, exist_ok=True)
n_data = 20 # 20枚生成

seed = 20220703
random.seed(seed)

for i in range(n_data):
    img = 255 * np.random.rand(256,256,3)
    cv2.imwrite('%s/%04d.jpg' % (outdir,i), img)

mydataset.py
'''
自作のdataloaderを作成
 ・自作のtransformerを使って、ランダムに反転させる処理を追加
'''
import torch
from glob import glob
import os
import cv2

class MyDatasets(torch.utils.data.Dataset):
    def __init__(self,input_dir,transform=None):
        self.input_paths = glob('%s/*.jpg' % input_dir)
        self.transform = transform
        
    def __len__(self):
        return len(self.input_paths)

    def __getitem__(self, idx):
        data_dict = {}
        input_path = self.input_paths[idx]
        data_dict['imgname'] = os.path.basename(input_path)
        data_dict['img'] = torch.from_numpy(cv2.imread(input_path)[...,::-1].copy()).permute(2,0,1)
        data_dict["flip"] = False

        # random flip
        if self.transform:
            data_dict = self.transform(data_dict)

        return data_dict
mytransform.py
'''
自作のtransformerを作成
 ・データ拡張のため、画像を左右反転させる処理をランダムに行う
'''
import torch
import random

class MyTransforms():
    def __init__(self):
        pass
    
    def __call__(self, data_dict):
        if random.random() < 0.5:
            data_dict["flip"] = True
            data_dict["img"] = torch.flip(data_dict["img"],[2])

        return data_dict
train.py
'''
学習コード
 ・データの読み込みの様子を確認するだけ
'''
import torch
import mytransform
import mydataset

seed = 20220703
n_epoch = 3
n_batch = 4
num_workers = 2

input_dir='./traindata'

torch.manual_seed(seed)
transform = mytransform.MyTransforms()
train_set = mydataset.MyDatasets(input_dir,transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=n_batch,shuffle=True,pin_memory=True,drop_last=True)


for epoch in range(n_epoch):
    print('epoch: %d' % (epoch+1))
    for itr, data_dict in enumerate(train_loader):
        print(data_dict["imgname"])
        print(data_dict["flip"])

実行結果

1回目の実行結果

epoch: 1
['0009.jpg', '0008.jpg', '0000.jpg', '0015.jpg']
tensor([False, False, False, False])
['0017.jpg', '0005.jpg', '0010.jpg', '0014.jpg']
tensor([ True,  True,  True, False])
['0001.jpg', '0018.jpg', '0011.jpg', '0003.jpg']
tensor([False, False,  True, False])
['0016.jpg', '0019.jpg', '0012.jpg', '0002.jpg']
tensor([ True,  True, False, False])
['0006.jpg', '0013.jpg', '0007.jpg', '0004.jpg']
tensor([ True, False, False, False])
epoch: 2
['0007.jpg', '0004.jpg', '0017.jpg', '0002.jpg']
tensor([ True, False, False,  True])
['0006.jpg', '0010.jpg', '0008.jpg', '0018.jpg']
tensor([False, False, False, False])
['0016.jpg', '0009.jpg', '0013.jpg', '0005.jpg']
tensor([ True, False, False,  True])
['0001.jpg', '0003.jpg', '0015.jpg', '0019.jpg']
tensor([ True, False,  True, False])
['0011.jpg', '0000.jpg', '0014.jpg', '0012.jpg']
tensor([False,  True, False,  True])
epoch: 3
['0008.jpg', '0003.jpg', '0011.jpg', '0015.jpg']
tensor([False,  True, False,  True])
['0014.jpg', '0019.jpg', '0000.jpg', '0006.jpg']
tensor([False,  True, False, False])
['0012.jpg', '0009.jpg', '0007.jpg', '0016.jpg']
tensor([False,  True, False, False])
['0018.jpg', '0002.jpg', '0004.jpg', '0017.jpg']
tensor([ True, False, False,  True])
['0010.jpg', '0013.jpg', '0001.jpg', '0005.jpg']
tensor([False,  True,  True, False])

各エポックごとに読み込むデータにランダム性を持たせながら...

2回目の実行結果

epoch: 1
['0009.jpg', '0008.jpg', '0000.jpg', '0015.jpg']
tensor([False, False, False, False])
['0017.jpg', '0005.jpg', '0010.jpg', '0014.jpg']
tensor([ True,  True,  True, False])
['0001.jpg', '0018.jpg', '0011.jpg', '0003.jpg']
tensor([False, False,  True, False])
['0016.jpg', '0019.jpg', '0012.jpg', '0002.jpg']
tensor([ True,  True, False, False])
['0006.jpg', '0013.jpg', '0007.jpg', '0004.jpg']
tensor([ True, False, False, False])
epoch: 2
['0007.jpg', '0004.jpg', '0017.jpg', '0002.jpg']
tensor([ True, False, False,  True])
['0006.jpg', '0010.jpg', '0008.jpg', '0018.jpg']
tensor([False, False, False, False])
['0016.jpg', '0009.jpg', '0013.jpg', '0005.jpg']
tensor([ True, False, False,  True])
['0001.jpg', '0003.jpg', '0015.jpg', '0019.jpg']
tensor([ True, False,  True, False])
['0011.jpg', '0000.jpg', '0014.jpg', '0012.jpg']
tensor([False,  True, False,  True])
epoch: 3
['0008.jpg', '0003.jpg', '0011.jpg', '0015.jpg']
tensor([False,  True, False,  True])
['0014.jpg', '0019.jpg', '0000.jpg', '0006.jpg']
tensor([False,  True, False, False])
['0012.jpg', '0009.jpg', '0007.jpg', '0016.jpg']
tensor([False,  True, False, False])
['0018.jpg', '0002.jpg', '0004.jpg', '0017.jpg']
tensor([ True, False, False,  True])
['0010.jpg', '0013.jpg', '0001.jpg', '0005.jpg']
tensor([False,  True,  True, False])

2回目でも1回目と同様の読み込み結果を得ることができました。

疑問点と補足

今回、シードを固定した部分は torch.manual_seed(seed) だけなのにも関わらず、なぜか mytransform.py で使用している random.random() の値も再現性が保たれているところが疑問として残っています。
torch.manual_seed() を使用したら random.random() の値も固定されるものなのかと思い、以下のような実験を行いましたが、

rand_test.py
import random
import torch
seed = 20220703
torch.manual_seed(seed)
print(random.random())

...

output:
1回目の実行結果:
0.9445797320532637
2回目の実行結果:
0.7288089413186923

のように再現性は保たれていません...
気づいたことがありましたら、ご教授いただけると幸いです。

なお、dataloaderの__getitem__内で乱数を使用し、workerごとに異なるランダム出力を得るためには、https://qiita.com/yagays/items/d413787a78aae825dbd3 で書かれているようにworker_init_fnを定義する必要があります。
今回はshuffle以外でランダム性を持たせたいデータ拡張部分をtransform内に書いたのでその問題は起こっていないようです。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0