検証時の推定精度が最大値になるたびにbest_model
を更新する例を考えてみます。以下の例ではbest_model
がmodel
自体を参照しているため、model
が更新されるたびにbest_model
のパラメータも更新されてしまいます。これではダメです。
model = model.eval()
best_accuracy = 0
...
for epoch in range(100):
for idx, data in enumerate(data_loader):
...
if cur_accuracy > best_accuracy:
best_model = model
best_accuracy = cur_accuracy
torch.save(best_model.state_dict(), 'model.pt')
したがって以下のようにdeepcopy
を使いましょう。
import copy
...
model = model.eval()
best_accuracy = 0
...
for epoch in range(100):
for idx, data in enumerate(data_loader):
...
if cur_accuracy > best_accuracy:
best_model = copy.deepcopy(model)
best_accuracy = cur_accuracy
torch.save(best_model.state_dict(), 'model.pt')
もしくは、最高精度が更新されるたびにモデルを保存してしまうのもありです。
model = model.eval()
best_accuracy = 0
...
for epoch in range(100):
for idx, data in enumerate(data_loader):
...
if cur_accuracy > best_accuracy:
torch.save(model.state_dict(), 'model.pt')
best_accuracy = cur_accuracy