17
16

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Pytorchの再現性最強チェックポイントの作り方メモ

Last updated at Posted at 2022-01-20

再現性を担保するために脳死で最強のチェックポイントを作るためのメモ。
僕の環境では以下で全部ですが、他にも追加した方が良いものがあればコメントください。

全部盛り

とりあえず以下をコピペすれば再現性は最強になる。

保存

import random
import torch
import numpy as np
from apex import amp

model_to_save = model.module if hasattr(model, "module") else model # DataParallelを使用している場合はmodel.moduleを取り出す。
checkpoint = {
    "model": model_to_save.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
    "amp": amp.state_dict(), # apex混合精度を使用する場合は必要
    "random": random.getstate(),
    "np_random": np.random.get_state(), # numpy.randomを使用する場合は必要
    "torch": torch.get_rng_state(),
    "torch_random": torch.random.get_rng_state(),
    "cuda_random": torch.cuda.get_rng_state(), # gpuを使用する場合は必要
    "cuda_random_all": torch.cuda.get_rng_state_all(), # 複数gpuを使用する場合は必要
}

torch.save(checkpoint, "checkpoint.bin")

読込

checkpoint = torch.load("checkpoint.bin") 

if hasattr(model, "module"):  # DataParallelを使用した場合
    model.module.load_state_dict(checkpoint["model"])
else:
    model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])
amp.load_state_dict(checkpoint["amp"])
random.setstate(checkpoint["random"])
np.random.set_state(checkpoint["np_random"])
torch.set_rng_state(checkpoint["torch"])
torch.random.set_rng_state(checkpoint["torch_random"])
torch.cuda.set_rng_state(checkpoint["cuda_random"]) # gpuを使用する場合は必要
torch.cuda.torch.cuda.set_rng_state_all(checkpoint["cuda_random_all"]) # 複数gpuを使用する場合は必要

以下でそれぞれについて簡単に説明する。

モデルと学習関連

modelとoptimizerは当然保存。
schedulerも使用するなら保存しておいた方が吉。

保存

model_to_save = model.module if hasattr(model, "module") else model # DataParallelを使用している場合はmodel.moduleを取り出す。
checkpoint = {
    "model": model_to_save.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
}

読込

if hasattr(model, "module"):  # DataParallelを使用した場合
    model.module.load_state_dict(checkpoint["model"])
else:
    model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])

混合精度(apex)

apexを使って混合精度で学習する場合は、スケーリング変数とかを保存する必要がある。

保存

checkpoint = {
    "amp": amp.state_dict(),
}

読込

amp.load_state_dict(checkpoint["amp"])

乱数状態(Random State)

乱数状態はこれで完璧なはず。
使用しないモジュール関係は取り除いて良し。

保存

checkpoint = {
    "random": random.getstate(),
    "np_random": np.random.get_state(),
    "torch": torch.get_rng_state(),
    "torch_random": torch.random.get_rng_state(),
    "cuda_random": torch.cuda.get_rng_state(), # gpuを使用する場合は必要
    "cuda_random_all": torch.cuda.get_rng_state_all(), # 複数gpuを使用する場合は必要
}

読込

random.setstate(checkpoint["random"])
np.random.set_state(checkpoint["np_random"])
torch.set_rng_state(checkpoint["torch"])
torch.random.set_rng_state(checkpoint["torch_random"])
torch.cuda.set_rng_state(checkpoint["cuda_random"]) # gpuを使用する場合は必要
torch.cuda.torch.cuda.set_rng_state_all(checkpoint["cuda_random_all"]) # 複数gpuを使用する場合は必要
17
16
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
17
16

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?