LoginSignup
2
2

More than 1 year has passed since last update.

Pytorch 訓練時にベストモデルを更新する方法

Posted at

検証時の推定精度が最大値になるたびにbest_modelを更新する例を考えてみます。以下の例ではbest_modelmodel自体を参照しているため、modelが更新されるたびにbest_modelのパラメータも更新されてしまいます。これではダメです。

model = model.eval()
best_accuracy = 0
...
for epoch in range(100):
    for idx, data in enumerate(data_loader):
                    ...
    if cur_accuracy > best_accuracy:
        best_model = model
        best_accuracy = cur_accuracy
torch.save(best_model.state_dict(), 'model.pt')

したがって以下のようにdeepcopyを使いましょう。

import copy
...
model = model.eval()
best_accuracy = 0
...
for epoch in range(100):
    for idx, data in enumerate(data_loader):
                    ...
    if cur_accuracy > best_accuracy:
        best_model = copy.deepcopy(model)
        best_accuracy = cur_accuracy
torch.save(best_model.state_dict(), 'model.pt')

もしくは、最高精度が更新されるたびにモデルを保存してしまうのもありです。

model = model.eval()
best_accuracy = 0
...
for epoch in range(100):
    for idx, data in enumerate(data_loader):
                    ...
    if cur_accuracy > best_accuracy:
        torch.save(model.state_dict(), 'model.pt')
        best_accuracy = cur_accuracy

参考

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2