広島大学で社会基盤(土木)を専攻している3年生です。
現在は1年間休学して建設系のIT企業でインターンをしています。
以前、超解像度モデルであるSRGANとESRGANを学習と可視化を行ってみました。
超解像度がどんなものか試してみたい方はよければ見てみてください。
今回はpytorch lightningを用いて学習の可視化、学習途中の推論画像の可視化、ウェイトファイルの作成を行いました。
pytorch lightning(URL)とは
簡潔に言うとpytorchでのモデル学習を簡潔にしてくれるライブラリです。pytorchで使えるラッパーにはigniteもあるので今後使用してみたいと思います。
こちらのURLにGANでpytorch lightningを使用した例が載せられているので参考にしました。
実装コード・感想
実装コードはこちらのgithubです。
通常GANの学習では生成器と識別器で交互に学習します。ESRGANはwarmup_batchsを設けているため指定したイテレータ数は生成器だけを学習するためpytorch lightningで学習させる時には予め学習させておいたウェイトファイルを読み込ませた生成器と用いて学習を行いました。一応ウェイトファイルのリンクをあげておきます。よかったら使ってください。
少し手間だったのでもっと簡単なやり方があったらコメントで教えていただきたいです、、、。
あえて2回学習しました。
こんな感じでtensorboardを使うことができます。
なぜか画質がドライブに保存したものと違う感じになってしまっていますが学習途中の画像を可視化しました、、、
あとversion_0の方がglobal_stepを引数に入れるのを忘れてpytorch lightningのサンプルものをそのまま使ってしまいましたが見逃してください。
感想
パラメータの調整や学習経過などを比較しやすいためこれからkaggleや勉強の際には積極的に使っていくと慣れでどんどん使いやすくなる感じがしています。