例えば、一定時間でジョブが停止してしまうような計算サーバーで、
一定時間内に終わらない学習を行う場合、チェックポイントをとり、
新しいジョブを投げてチェックポイントから学習を再開することになる。
この時、例えはDataLoader
でshuffle=True
に設定している場合等は、
ジョブ再開時に乱数が初期化されてしまい学習の再現性が担保されなくなってしまう。
本記事には乱数の状態を保存することで学習再開時の再現性を担保する方法を記す。
PytorchにおけるRandom Stateの保存
Pytorchにおいて乱数の生成状態は次のように一時保存と読み出しできる。
(Python標準モジュールのrandom
を使用しない場合は削っても問題ないです。)
import random
import torch
def save_random_state(file_path):
random_state = random.getstate()
torch_random_state = torch.random.get_rng_state()
# GPUを使用する場合 torch.cuda.get_rng_state() も保存&ロードする必要あり。
torch.save({"random_state":random_state, "torch_random_state":torch_random_state}, file_path)
def load_random_state(file_path):
state = torch.load(file_path)
random.setstate(state["random_state"])
torch.random.set_rng_state(state["torch_random_state"])
検証
これは以下のようなコードで検証できる。
import random
import torch
def save_random_state(file_path):
random_state = random.getstate()
torch_random_state = torch.random.get_rng_state()
torch.save({"random_state":random_state, "torch_random_state":torch_random_state}, file_path)
def load_random_state(file_path):
state = torch.load(file_path)
random.setstate(state["random_state"])
torch.random.set_rng_state(state["torch_random_state"])
class RandomIntDataset(torch.utils.data.Dataset):
"""
選択されたインデックスとランダムな値を返すだけのデータセット
"""
def __init__(self):
self.length = 10
def __len__(self):
return self.length
def __getitem__(self, i):
return i, random.randint(0,9)
if __name__ == '__main__':
print("ランダム生成状態の保存")
save_random_state("random_state.pt")
rand_dataset = RandomIntDataset()
data_loader = torch.utils.data.DataLoader(rand_dataset, batch_size=4, shuffle=True, num_workers=4)
for batch in data_loader:
print(batch[0].tolist(), batch[1].tolist())
for batch in data_loader:
print(batch[0].tolist(), batch[1].tolist())
print("ランダム生成状態の読み出し")
load_random_state("random_state.pt")
for batch in data_loader:
print(batch[0].tolist(), batch[1].tolist())
出力は以下のようになる。
ランダム生成状態の保存
[9, 6, 0, 7] [1, 1, 2, 0]
[4, 8, 5, 1] [5, 2, 1, 5]
[2, 3] [5, 1]
[1, 0, 2, 4] [9, 4, 6, 9]
[7, 3, 6, 9] [7, 6, 4, 5]
[5, 8] [4, 1]
ランダム生成状態の読み出し
[9, 6, 0, 7] [1, 1, 2, 0]
[4, 8, 5, 1] [5, 2, 1, 5]
[2, 3] [5, 1]
左側の配列は、Datasetのインデックス, 右側の配列はDataset内部で生成された乱数である。
保存後と読出し後で、無事同じ乱数を生成できていることがわかる。