3
3

More than 3 years have passed since last update.

pytorch-lightning (0.5.3.2) でEarly Stoppingを使用したときに'val_loss'が見つからないと言われるバグ

Last updated at Posted at 2020-01-09

前提

(※2020年1月9日の記事です。近い将来問題なくなると思います。)

Pytorch LightningはPyTorchのKerasライクなラッパです。
モデル・学習・データ周りをコンパクトに書けることが魅力です。

詳しくは@fam_taroさんの下記記事をご覧ください。

PyTorch 三国志(Ignite・Catalyst・Lightning) - Qiita

とても便利な感じがするのですが、インストール早々バグにぶち当たってしまったので、その内容と解決方法の報告です。

環境

OS: macOS 10.14.6
Python: 3.7.3
pytorch-lightning: 0.5.3.2
Pytorch Lightningのインストール方法: pip install pytorch-lightning

バグ

Pytorch LightningにはEarly Stoppingが実装されていて、
以下のようなコードで記述できます(スッキリ)。

モデル定義部分(抜粋)
import pytorch_lightning as pl
class MyModel(pl.LightningModule):
    ...
    def validation_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        return {'val_batch_loss': F.cross_entropy(y_hat, y)}

    def validation_end(self, outputs):
        val_loss = torch.stack([x['val_batch_loss'] for x in outputs]).mean()
        log = {'val_loss': val_loss}
        return {'log': log}
    ...
Early_Stopping周り
early_stop_callback = EarlyStopping(
    min_delta=0.00,
    patience=1,
    verbose=False,
    monitor='val_loss',
    mode='min',
)
model = MyModel()
trainer = pl.Trainer(early_stop_callback=early_stop_callback)
trainer.fit(model) 

しかし、いざ実行すると、以下のようなエラーが出てきてうまく動きませんでした。
(常に発生するのかは把握できていませんが、とりあえず私の実行環境では常に再現されました。)

Early stopping conditioned on metric `val_loss` which is not available. Available metrics are: loss,train_loss

このバグは公式のIssueでも報告されています。

https://github.com/williamFalcon/pytorch-lightning/issues/490

解決方法

最新のmasterブランチでは修正されているので、下記コマンドでインストールするとバグが直ります。

pip install git+https://github.com/williamFalcon/pytorch-lightning.git@master --upgrade

注意点

最新ブランチをインストールすると、2020年1月9日現在のドキュメントとAPIの齟齬が生じる場合がありそうです。

例: チェックポイント保存用クラスpytorch_lightning.callbacks.ModelCheckpointの初期化メソッドの引数変更(該当ページ)

  • pip install pytorch-lightningでインストールした場合(公式ドキュメントと同じ)
    • save_best_only: 最良モデルのみを保存するか否かをBool値で指定
  • 「解決方法」のコマンドで最新ブランチをインストールした場合(公式ドキュメントと異なる)
    • save_top_k: 上位何個を保存するか整数で指定
    • (save_best_onlyは取って代わられる形で削除)

恐らく、0.5.3.2を超えるバージョンが公開されたら、通常のpip install pytorch-lightningで大丈夫になると思います。

参考

3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3