73
58

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorchでEarlyStoppingを実装する

Last updated at Posted at 2021-04-10

#0.はじめに
今までKerasを使っていた人がいざpytorchを使うってなった際に、Kerasでは当たり前にあった機能がpytorchでは無い!というようなことに困惑することがあります。
KerasではcallbackとしてEarlyStoppingの機能が備わっていますが、Pytorchではデフォルトでこの機能は存在せず、自分で実装する必要があります。
今回はそれを実装したので共有しておきます。

参考:KerasのEarlyStopping

#1.EarlyStoppingって何なのか?
本記事を見るような方はすでにご存じかもしれませんが、一応説明します。

・そもそも何Epoch学習を回せばいいのかなんて初見ではわからない
・Epoch数を重ねるたびにlossは下がるが、訓練データに過学習する可能性がある
・せっかくlossが下がっても、そのまま放置しすぎるとlossが上がっていってしまうケースがある

こんなの時に便利なのがこのEarlyStoppingなのである。
★とりあえずEpoch数は大きく設定しておけばいい(EarlyStoppingが止めてくれるから)
★学習が収束した時に自動的に学習を止めてくれる(過学習防止)

#2.EarlyStoppingクラスを作成する

プログラム的には
・何回lossの最小値を更新しなかったら学習をやめるか?を決めて(patience)
・監視しているlossが最低値を更新できない数をカウントし(counter)
・監視しているlossが最低値を更新したときだけ学習済モデルを保存しておき、そのlossを記録(checkpoint)
・監視しているlossが設定数だけ最低値を更新できない場合に学習ループを抜ける(early_stop )
これらを実装すればいいだけである。

class EarlyStopping:
    """earlystoppingクラス"""
    
    def __init__(self, patience=5, verbose=False, path='checkpoint_model.pth'):
        """引数:最小値の非更新数カウンタ、表示設定、モデル格納path"""
        
        self.patience = patience    #設定ストップカウンタ
        self.verbose = verbose      #表示の有無
        self.counter = 0            #現在のカウンタ値
        self.best_score = None      #ベストスコア
        self.early_stop = False     #ストップフラグ
        self.val_loss_min = np.Inf   #前回のベストスコア記憶用
        self.path = path             #ベストモデル格納path
        
    def __call__(self, val_loss, model):
        """
        特殊(call)メソッド
        実際に学習ループ内で最小lossを更新したか否かを計算させる部分
        """
        score = -val_loss

        if self.best_score is None:  #1Epoch目の処理
            self.best_score = score   #1Epoch目はそのままベストスコアとして記録する
            self.checkpoint(val_loss, model)  #記録後にモデルを保存してスコア表示する
        elif score < self.best_score:  # ベストスコアを更新できなかった場合
            self.counter += 1   #ストップカウンタを+1
            if self.verbose:  #表示を有効にした場合は経過を表示
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')  #現在のカウンタを表示する 
            if self.counter >= self.patience:  #設定カウントを上回ったらストップフラグをTrueに変更
                self.early_stop = True
        else:  #ベストスコアを更新した場合
            self.best_score = score  #ベストスコアを上書き
            self.checkpoint(val_loss, model)  #モデルを保存してスコア表示
            self.counter = 0  #ストップカウンタリセット

    def checkpoint(self, val_loss, model):
        '''ベストスコア更新時に実行されるチェックポイント関数'''
        if self.verbose:  #表示を有効にした場合は、前回のベストスコアからどれだけ更新したか?を表示
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)  #ベストモデルを指定したpathに保存
        self.val_loss_min = val_loss  #その時のlossを記録する

#3.CIFER10の画像分類に組み込んで使う
サンプルとして作成したクラスを使って、CIFER10の画像分類で使ってみます。
pytorch初心者もわかるように、1文1文に何やってるか?のコメントも書いてますので参考にしてください。

また、分かりやすいようにearlystoppingに関連する部分は★~~★でコメント書いています。

##3-1.まずはCIFER10の分類に使うCNNクラス作成まで


import torch
import torchvision
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torchvision import models
from torch import nn,optim
import numpy as np
import matplotlib.pyplot as plt

class EarlyStopping:
    """★上で書いてるので省略★"""

class Mynet(nn.Module):
    """自作CNN"""
    def __init__(self):
        super(Mynet, self).__init__()

        # 畳み込み層の定義
        self.conv1 = nn.Conv2d(3, 6, 5)  # コンボリューション1
        self.conv2 = nn.Conv2d(6, 16, 5) # コンボリューション2
        # 全結合層の定義
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 全結合
        self.fc2 = nn.Linear(120, 84) # 全結合
        self.fc3 = nn.Linear(84, 10)  # CIFAR10のクラス数が10なので出力を10にする
        #プーリング層の定義
        self.pool = nn.MaxPool2d(2, 2)  # maxプーリング

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # conv1~relu~pool
        x = self.pool(F.relu(self.conv2(x)))  # conv2~relu~pool
        x = x.view(-1, 16 * 5 * 5)  # 平坦化(1次元化する)
        x = F.relu(self.fc1(x))  # fc1~relu
        x = F.relu(self.fc2(x))  # fc2~relu
        x = self.fc3(x) #fc3~10分類
        return x

##3-2.学習ループ関数

def train_net(n_epochs, train_loader, net, optimizer_cls, loss_fn, device='cpu'):
    """学習ループ部分"""
    #★EarlyStoppingクラスのインスタンス化★
    earlystopping = EarlyStopping(patience=1, verbose=True) #検証なのでわざとpatience=1回にしている
    
    losses = [] #loss遷移の記録用
    net.to(device) #ネットワークをGPUへ
    
    for epoch in range(n_epochs):
        running_loss = 0.0 #loss初期化
        net.train() #netをtrainingモードに
        
        for index, data in enumerate(train_loader):
            inputs, labels = data  #画像とラベルを取り出す
            inputs = inputs.to(device)  #画像をGPUへ
            labels = labels.to(device)  #ラベルもGPUに 
            
            optimizer_cls.zero_grad()  #勾配を初期化
            outputs = net(inputs)  #ネットワークで予測
            loss = loss_fn(outputs, labels)  #loss計算
            loss.backward()  #逆伝番
            optimizer.step()  #勾配を更新
            running_loss += loss.item()  #バッチごとのlossを足していく
        
        losses.append(running_loss / index) #loss遷移を記録
        print("epoch", epoch, ": ", running_loss / index) #学習経過の表示
        
        #★毎エポックearlystoppingの判定をさせる★
        earlystopping((running_loss / index), net) #callメソッド呼び出し
        if earlystopping.early_stop: #ストップフラグがTrueの場合、breakでforループを抜ける
            print("Early Stopping!")
            break
            
    return losses

##3-3.main部分

#前処理部分(Tensor化、正規化)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        (0.5, 0.5, 0.5), # RGB 平均
        (0.5, 0.5, 0.5)# RGB 標準偏差
    )
])

#CIFAR10をダウンロードしてデータセット作成
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
#データローダーを作成
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)

#ネットワークをインスタンス化
model = Mynet()

#学習設定
criterion = nn.CrossEntropyLoss()  #損失関数をクロスエントロピーに
optimizer = optim.Adam(model.parameters(), lr=0.001)  #オプティマイザはAdam
epochs = 10000  #エポック数 ★とりあえず大きく設定しておけばいいので1万としておいた★
device = torch.device("cuda:0" if torch.cuda. is_available() else "cpu")  #デバイス(GPU or CPU)設定 

#学習開始(作成した学習関数を呼び出す)
losses = train_net(n_epochs=epochs,
                   train_loader=trainloader,
                   net=model,
                   optimizer_cls=optimizer,
                   loss_fn=criterion,
                   device=device
                  )

すると、以下の実行結果が表示される。

実行結果
epoch 0 :  1.0410267220779432
Validation loss decreased (inf --> 1.041027).  Saving model ...
epoch 1 :  1.0064437736963907
Validation loss decreased (1.041027 --> 1.006444).  Saving model ...
・
(略)
・
epoch 25 :  0.7393526346215752
Validation loss decreased (0.741965 --> 0.739353).  Saving model ...
epoch 26 :  0.7438237962901487
EarlyStopping counter: 1 out of 1
Early Stopping!

・26epoch目でloss値が「0.743」となっており、それまでのベストスコア(25epoch目)である「0.739」を更新できなかった
・patience=1としているので、1回でもベストスコアを更新できなければ学習ループを抜ける設定

##3-4.loss推移+保存したモデル確認
lossの推移を見てみると、確かに26epoch目でloss値が上向いていることが確認できる。
さらに保存した学習済モデルに関してはこの26epoch目ではなく、25epoch目が保存されている

plt.plot(losses, label='train_loss')
plt.xlabel('Epochs')
plt.ylabel('loss')
plt.legend()
plt.show()

##3-5.ベストモデルを適応させる
このままだとmodel変数には26epoch目の学習が反映されている為、保存したモデルを一度適応させる必要がある。
これで最高性能のモデルを使える。 ※以下省略

model.load_state_dict(torch.load('checkpoint_model.pth'))
実行結果
<All keys matched successfully>

#4. さいごに
久々に画像を題材に記事書いてみました。
PytorchはKerasより記載量は多いものの、細かい部分をカスタマイズできるので今後はますます採用比率が上がると個人的には考えています。
それでは良きPytorchライフを!

#5.追記
実はpytorch lightningを使えばearlystoppingの機能を実装しなくても使用できます。
以下も参考にどうぞ。

73
58
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
73
58

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?