3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Pytorchで乱数生成状態の保存

Last updated at Posted at 2021-02-22

例えば、一定時間でジョブが停止してしまうような計算サーバーで、
一定時間内に終わらない学習を行う場合、チェックポイントをとり、
新しいジョブを投げてチェックポイントから学習を再開することになる。

この時、例えはDataLoadershuffle=Trueに設定している場合等は、
ジョブ再開時に乱数が初期化されてしまい学習の再現性が担保されなくなってしまう。

本記事には乱数の状態を保存することで学習再開時の再現性を担保する方法を記す。

PytorchにおけるRandom Stateの保存

Pytorchにおいて乱数の生成状態は次のように一時保存と読み出しできる。
(Python標準モジュールのrandomを使用しない場合は削っても問題ないです。)

import random
import torch

def save_random_state(file_path):
    random_state = random.getstate()
    torch_random_state = torch.random.get_rng_state()
    # GPUを使用する場合 torch.cuda.get_rng_state() も保存&ロードする必要あり。
    torch.save({"random_state":random_state, "torch_random_state":torch_random_state}, file_path)

def load_random_state(file_path):
    state = torch.load(file_path)
    random.setstate(state["random_state"])
    torch.random.set_rng_state(state["torch_random_state"])

検証

これは以下のようなコードで検証できる。

import random
import torch

def save_random_state(file_path):
    random_state = random.getstate()
    torch_random_state = torch.random.get_rng_state()
    torch.save({"random_state":random_state, "torch_random_state":torch_random_state}, file_path)

def load_random_state(file_path):
    state = torch.load(file_path)
    random.setstate(state["random_state"])
    torch.random.set_rng_state(state["torch_random_state"])

class RandomIntDataset(torch.utils.data.Dataset):
    """
    選択されたインデックスとランダムな値を返すだけのデータセット
    """
    def __init__(self):
        self.length = 10

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        return i, random.randint(0,9)

if __name__ == '__main__':
    print("ランダム生成状態の保存")
    save_random_state("random_state.pt")

    rand_dataset = RandomIntDataset()
    data_loader = torch.utils.data.DataLoader(rand_dataset, batch_size=4, shuffle=True, num_workers=4)
    for batch in data_loader:
        print(batch[0].tolist(), batch[1].tolist())

    for batch in data_loader:
        print(batch[0].tolist(), batch[1].tolist())

    print("ランダム生成状態の読み出し")
    load_random_state("random_state.pt")
    for batch in data_loader:
        print(batch[0].tolist(), batch[1].tolist())

出力は以下のようになる。

ランダム生成状態の保存
[9, 6, 0, 7] [1, 1, 2, 0]
[4, 8, 5, 1] [5, 2, 1, 5]
[2, 3] [5, 1]
[1, 0, 2, 4] [9, 4, 6, 9]
[7, 3, 6, 9] [7, 6, 4, 5]
[5, 8] [4, 1]
ランダム生成状態の読み出し
[9, 6, 0, 7] [1, 1, 2, 0]
[4, 8, 5, 1] [5, 2, 1, 5]
[2, 3] [5, 1]

左側の配列は、Datasetのインデックス, 右側の配列はDataset内部で生成された乱数である。
保存後と読出し後で、無事同じ乱数を生成できていることがわかる。

3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?