再現性を担保するために脳死で最強のチェックポイントを作るためのメモ。
僕の環境では以下で全部ですが、他にも追加した方が良いものがあればコメントください。
全部盛り
とりあえず以下をコピペすれば再現性は最強になる。
保存
import random
import torch
import numpy as np
from apex import amp
model_to_save = model.module if hasattr(model, "module") else model # DataParallelを使用している場合はmodel.moduleを取り出す。
checkpoint = {
"model": model_to_save.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"amp": amp.state_dict(), # apex混合精度を使用する場合は必要
"random": random.getstate(),
"np_random": np.random.get_state(), # numpy.randomを使用する場合は必要
"torch": torch.get_rng_state(),
"torch_random": torch.random.get_rng_state(),
"cuda_random": torch.cuda.get_rng_state(), # gpuを使用する場合は必要
"cuda_random_all": torch.cuda.get_rng_state_all(), # 複数gpuを使用する場合は必要
}
torch.save(checkpoint, "checkpoint.bin")
読込
checkpoint = torch.load("checkpoint.bin")
if hasattr(model, "module"): # DataParallelを使用した場合
model.module.load_state_dict(checkpoint["model"])
else:
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])
amp.load_state_dict(checkpoint["amp"])
random.setstate(checkpoint["random"])
np.random.set_state(checkpoint["np_random"])
torch.set_rng_state(checkpoint["torch"])
torch.random.set_rng_state(checkpoint["torch_random"])
torch.cuda.set_rng_state(checkpoint["cuda_random"]) # gpuを使用する場合は必要
torch.cuda.torch.cuda.set_rng_state_all(checkpoint["cuda_random_all"]) # 複数gpuを使用する場合は必要
以下でそれぞれについて簡単に説明する。
モデルと学習関連
modelとoptimizerは当然保存。
schedulerも使用するなら保存しておいた方が吉。
保存
model_to_save = model.module if hasattr(model, "module") else model # DataParallelを使用している場合はmodel.moduleを取り出す。
checkpoint = {
"model": model_to_save.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
}
読込
if hasattr(model, "module"): # DataParallelを使用した場合
model.module.load_state_dict(checkpoint["model"])
else:
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])
混合精度(apex)
apexを使って混合精度で学習する場合は、スケーリング変数とかを保存する必要がある。
保存
checkpoint = {
"amp": amp.state_dict(),
}
読込
amp.load_state_dict(checkpoint["amp"])
乱数状態(Random State)
乱数状態はこれで完璧なはず。
使用しないモジュール関係は取り除いて良し。
保存
checkpoint = {
"random": random.getstate(),
"np_random": np.random.get_state(),
"torch": torch.get_rng_state(),
"torch_random": torch.random.get_rng_state(),
"cuda_random": torch.cuda.get_rng_state(), # gpuを使用する場合は必要
"cuda_random_all": torch.cuda.get_rng_state_all(), # 複数gpuを使用する場合は必要
}
読込
random.setstate(checkpoint["random"])
np.random.set_state(checkpoint["np_random"])
torch.set_rng_state(checkpoint["torch"])
torch.random.set_rng_state(checkpoint["torch_random"])
torch.cuda.set_rng_state(checkpoint["cuda_random"]) # gpuを使用する場合は必要
torch.cuda.torch.cuda.set_rng_state_all(checkpoint["cuda_random_all"]) # 複数gpuを使用する場合は必要