GANを学ぶ機会があり、どうせなら新しい技術まで知りたいと思ったのでStyle GANについて実装レベルで理解しようと考えました。
TensorFlowとCUDAのバージョン問題でイライラしていたこともあり、PyTorchで実装しているコードがないかと調査したところ、下記の記事があったため、ここのコードを読み込むことにしました。
コードは理解しやすかったのですが、
- Trainer部分が入り組んでいた
- Shape(テンソルの形状)を追いかけたかった
- Style GANに使われている個別技術(PixelNormalizationなど)の仕組みを理解したかった
ということで、一部クラスを組み換えつつdocstringコメントを追加するなどして遊びました。
実装
- 動作環境
- GPU: GTX750Ti
- OS: Windows10
- Python: 3.7.9
- PyTorch: 1.5.0
- CUDA: 10.1
CUDAのについては自分が書いたこの記事を参考に、「グラフィックドライバのアップデート」と「CUDA Toolkitのインストール」を実行しておく必要があります。
なお、NVIDIA製グラフィックボードがない場合はGPU演算できません。
GPU演算ができなくても実行できますが、1バッチ通すのに数秒かかるので大人しくGPU環境を導入した方がよいでしょう。
ソースコードはこちらです↓
GitHub: StyleGanPytorch
設定はsetting.jsonから変更できます。
デフォルトの設定では第7層のバッチサイズを2まで落としているのでまともに学習が進みません。
GPUに余裕がある場合はこの値を大きくしましょう。
GPUスペック以上に大きくすると、メモリが確保できずエラー落ちします。
今後の展望
クラスの組み換えとコメント部分以外は前述記事のコードほぼそのままなので恐縮ですが、大変よい勉強にはなりました。
ちなみに機械学習自体の動作に関しては、私のPCのGPUがGTX750Tiということもあり、大幅にバッチサイズを落とさざるを得ませんでした…
このため、記事通りのモデルでは間違いなく学習が終わりません(笑)
後ほど、新しいPCを購入するまでのつなぎとして、モデルのパラメータを減らすなどの作業にチャレンジしてみようと思います。
イラスト生成などできるといいよなあ…
なお、Style GANの原論文で用いられているモデルをそのまま用いる場合、クラウドに湯水の如く金を落とさないと再現実験もできないと思いますので最初から諦めています。
(リアルな画像生成系はスペックとの戦いですよね…)