この章は深層学習 (アスキードワンゴ) と https://github.com/Bjarten/early-stopping-pytorch を参考に書かれています。
十分な表現要領を持つ大きなモデルを訓練して、あるタスクに対して学習するとき訓練誤差は減少するが、検証誤差が再び増加し始めることがあります。
そこで、検証誤差が改善されるたびに、モデルを保存することにします。
一定のエポック数検証誤差が改善されない場合、学習は終了します。
コードは以下のようになります。
以下のコードをearlystopping.pyとして保存してください。
import numpy as np
import torch
class EarlyStopping:
def __init__(self, patience=7, verbose=False):
self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_score = None
self.early_stop = False
self.val_loss_min = np.Inf
self.force_cancel = False
def __call__(self, val_loss, model):
score = -val_loss
if self.best_score is None:
self.best_score = score
self.save_checkpoint(val_loss, model)
elif score < self.best_score:
self.counter += 1
print(
f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(val_loss, model)
self.counter = 0
def save_checkpoint(self, val_loss, model):
if self.verbose:
print(
f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
torch.save(model.state_dict(), 'models/model.pth')
self.val_loss_min = val_loss