PyTorchでアーリーストップを行うときのコードをすぐに忘れるのでメモ.
PyTorchでは自由度が高いので自分で実装する.
early_stop = 5
best_score = 0.
not_improving = 0
for epoch in range(n_epochs):
loss_train = train_net(train_loader)
loss_valid, score = valid_net(valid_loader)
print(f"Epoch: {epoch+1}, lr: {optimizer.param_groups[0]['lr']:.7f}, loss_train: {loss_train:.5f}, loss_valid: {loss_valid:.5f}, score: {score:.6f}")
not_improving += 1
if score > best_score:
best_score = score
not_improving = 0
best_loss = loss_valid
if not_improving == early_stop:
print("Early Stopping...")
break