LoginSignup
0
2

More than 3 years have passed since last update.

深層学習(Pytorch)を用いた、Kaggle Titanic実践 PART 6 (Early Stopping)

Posted at

この章は深層学習 (アスキードワンゴ)https://github.com/Bjarten/early-stopping-pytorch を参考に書かれています。

十分な表現要領を持つ大きなモデルを訓練して、あるタスクに対して学習するとき訓練誤差は減少するが、検証誤差が再び増加し始めることがあります。
そこで、検証誤差が改善されるたびに、モデルを保存することにします。
一定のエポック数検証誤差が改善されない場合、学習は終了します。

コードは以下のようになります。

以下のコードをearlystopping.pyとして保存してください。

import numpy as np
import torch

class EarlyStopping:

    def __init__(self, patience=7, verbose=False):

        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.force_cancel = False

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score:
            self.counter += 1
            print(
                f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):

        if self.verbose:
            print(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), 'models/model.pth')
        self.val_loss_min = val_loss
0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2