前提
(※2020年1月9日の記事です。近い将来問題なくなると思います。)
Pytorch LightningはPyTorchのKerasライクなラッパです。
モデル・学習・データ周りをコンパクトに書けることが魅力です。
詳しくは@fam_taroさんの下記記事をご覧ください。
PyTorch 三国志(Ignite・Catalyst・Lightning) - Qiita
とても便利な感じがするのですが、インストール早々バグにぶち当たってしまったので、その内容と解決方法の報告です。
環境
OS: macOS 10.14.6
Python: 3.7.3
pytorch-lightning: 0.5.3.2
Pytorch Lightningのインストール方法: pip install pytorch-lightning
バグ
Pytorch LightningにはEarly Stoppingが実装されていて、
以下のようなコードで記述できます(スッキリ)。
import pytorch_lightning as pl
class MyModel(pl.LightningModule):
...
def validation_step(self, batch, batch_nb):
x, y = batch
y_hat = self.forward(x)
return {'val_batch_loss': F.cross_entropy(y_hat, y)}
def validation_end(self, outputs):
val_loss = torch.stack([x['val_batch_loss'] for x in outputs]).mean()
log = {'val_loss': val_loss}
return {'log': log}
...
early_stop_callback = EarlyStopping(
min_delta=0.00,
patience=1,
verbose=False,
monitor='val_loss',
mode='min',
)
model = MyModel()
trainer = pl.Trainer(early_stop_callback=early_stop_callback)
trainer.fit(model)
しかし、いざ実行すると、以下のようなエラーが出てきてうまく動きませんでした。
(常に発生するのかは把握できていませんが、とりあえず私の実行環境では常に再現されました。)
Early stopping conditioned on metric `val_loss` which is not available. Available metrics are: loss,train_loss
このバグは公式のIssueでも報告されています。
https://github.com/williamFalcon/pytorch-lightning/issues/490
解決方法
最新のmasterブランチでは修正されているので、下記コマンドでインストールするとバグが直ります。
pip install git+https://github.com/williamFalcon/pytorch-lightning.git@master --upgrade
注意点
最新ブランチをインストールすると、2020年1月9日現在のドキュメントとAPIの齟齬が生じる場合がありそうです。
例: チェックポイント保存用クラスpytorch_lightning.callbacks.ModelCheckpoint
の初期化メソッドの引数変更(該当ページ)
-
pip install pytorch-lightning
でインストールした場合(公式ドキュメントと同じ)-
save_best_only
: 最良モデルのみを保存するか否かをBool値で指定
-
- 「解決方法」のコマンドで最新ブランチをインストールした場合(公式ドキュメントと異なる)
-
save_top_k
: 上位何個を保存するか整数で指定 - (
save_best_only
は取って代わられる形で削除)
-
恐らく、0.5.3.2を超えるバージョンが公開されたら、通常のpip install pytorch-lightning
で大丈夫になると思います。