基本的な用法
torch.save
PyTorchで学習させたmodelの保存にはtorch.save
を用います。
import torch
# https://pytorch.org/docs/main/generated/torch.save.html
# Save to file
x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
torch.save(x, "tensor.pt")
上記を実行するとカレントディレクトリにtensor.pt
というファイルが保存されます。
torch.load
保存させたmodelの読み込みにはtorch.load
を用います。下記が基本的な実行例です。
model = torch.load("tensor.pt", weights_only=True)
print(type(model), model)
・実行結果
<class 'torch.Tensor'> tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
CPU環境で呼び出す場合はtorch.load
の引数にmap_location=torch.device("cpu")
を入れて実行します。