PyTorch Lightningは、ML開発者向けの軽量なPyTorchラッパ。
学習とかがめっちゃ簡単で、他のいろんなプラットフォームとの組み合わせが使えて汎用性高いってのを聞いたからとりあえず使ってみた。
いろんな機能があったけど、これだけ理解しといたらいけるやろって感じの機能だけ紹介します。
・基本的な使い方
・ログ機能
・CALLBACK
この三つの使い方を今回は書いていきます
PyTorchの使い方知らない人にもわかるように記載しますが、PyTorchも知ってた方が分かりやすいと思う。
##基本的な使い方
簡潔に言うと学習までの手順ですることはこの三つ
・モデルを作る
・推論時の設定
・学習開始
機械学習においてすることはだいたいこれだけ
しかし、PyTorch Lightningは通常のPyTorchによる学習よりもコードが非常に簡潔!
モデル作成のステップでしっかりと記述できれば非常に便利なので、ぜひ最後まで読んでください。
###モデルを作る
もっとも大事と言えるこのステップ。通常PyTorch等の機械学習では、データの用意や各学習ステップの設計が必要ですが、PyTorch Lightningではこれらの必要な処理をほとんどモデルの内部で行います。
では実際にモデルを定義していきましょう。
分かりやすいようにMNISTのデータを用いて解説していきます。
モデル作成で、継承するPyTorch Lightningのモジュールは、
pytorch_lightning.LightningModuleというものです。
これを利用して以下のようにモデルを定義します。
続いて先ほど定義したモデルの中に以下の関数を定義します。