Help us understand the problem. What is going on with this article?

StyleGAN2で未知のポケモンを生み出す[後編]

前回は概要や生成結果を示しました.今回は実際にGTX1070で動かすためにStyleGAN2の公式実装から変更した点などを紹介します.

諸事情で学習済みモデルが吹っ飛んだので,その辺の注意も含めてかいておきます!

StyleGAN2の公式実装

公式実装の変更点

実際にいじったのは主にrun_training.pyとtraining/dataset.pyです.

モデルのデータはだいたい300MBくらいあってかなり重いので,保存頻度を考えないと容量が尽きて学習が終わります.僕の場合,途中で学習が止まった上にモデルを上書きする設定にしていたため空上書きされて二日分の学習済みモデルが吹っ飛びました.定期的に複数のモデルで保存することをお勧めします.

run_training.py
def run(...):
  ...
  # 生成サンプルはtickごとに出力・ネットワークは5tickに一回
  train.image_snapshot_ticks = 1
  train.network_snapshot_ticks = 5

  # 今回は画像サイズ64でモデルを作成
  dataset_args = EasyDict(tfrecord_dir=dataset, resolution = 64)

こちらで紹介されていたメモリエラーの解消のための修正

training/dataset.py
class TFRecordDataset:
  def __init__(...):
    ...
    # Load labels.
    assert max_label_size == 'full' or max_label_size >= 0
    #self._np_labels = np.zeros([1<<30, 0], dtype=np.float32)
    self._np_labels = np.zeros([1<<20, 0], dtype=np.float32)

また,GPUの処理が長くなりすぎると学習タスクがキルされてしまいます.僕もこれにやられました.対策としてはDWORD値の設定でタイムアウトを切ればいいみたいです.(http://www.field-and-network.jp/rihei/20121028223437.php)
学習前に設定しておくことをお勧めします.

モデルの評価指標・学習の様子の監視

GANの定量的評価は難しい課題の一つですが,多くの研究ではFID(Frechet Inception Distance)と呼ばれる手法で生成画像の品質を評価しています.データセットの画像と生成画像を特徴量抽出モデルに入力して,その特徴量の分布間のFrechet距離を計算する手法です.

扱う多変量正規分布の次元数の関係でデータセットを最低でも4000枚程度はデータセットと生成画像から画像を用意する必要があり,特にデータセットの読み込みに時間がかかります.この処理を抜くと指標がなくなるので学習をいつ止めていいかもわからないので消せませんが,高速化できないものでしょうか...

FIDはresults/.../metrix-fid50k.txtに出力されていくので,順調に学習が進んでいるか定期的に確認していくことになります.

metrix-fid50k.txt
network-snapshot-              time 19m 34s      fid50k 278.0748
network-snapshot-              time 19m 34s      fid50k 382.7474
network-snapshot-              time 19m 34s      fid50k 338.3625
network-snapshot-              time 19m 24s      fid50k 378.2344
network-snapshot-              time 19m 33s      fid50k 306.3552
network-snapshot-              time 19m 33s      fid50k 173.8370
network-snapshot-              time 19m 30s      fid50k 112.3612
network-snapshot-              time 19m 31s      fid50k 99.9480
network-snapshot-              time 19m 35s      fid50k 90.2591
network-snapshot-              time 19m 38s      fid50k 75.5776
network-snapshot-              time 19m 39s      fid50k 67.8876
network-snapshot-              time 19m 39s      fid50k 66.0221
network-snapshot-              time 19m 46s      fid50k 63.2856
network-snapshot-              time 19m 40s      fid50k 64.6719
network-snapshot-              time 19m 31s      fid50k 64.2135
network-snapshot-              time 19m 39s      fid50k 63.6304
network-snapshot-              time 19m 42s      fid50k 60.5562
network-snapshot-              time 19m 36s      fid50k 59.4038
network-snapshot-              time 19m 36s      fid50k 57.2236
network-snapshot-              time 19m 40s      fid50k 56.9055
network-snapshot-              time 19m 47s      fid50k 56.5965
network-snapshot-              time 19m 34s      fid50k 56.5844
network-snapshot-              time 19m 38s      fid50k 56.4158
network-snapshot-              time 19m 34s      fid50k 54.0568
network-snapshot-              time 19m 32s      fid50k 54.0307
network-snapshot-              time 19m 40s      fid50k 54.0492
network-snapshot-              time 19m 32s      fid50k 54.1482
network-snapshot-              time 19m 38s      fid50k 53.3513
network-snapshot-              time 19m 32s      fid50k 53.8889
network-snapshot-              time 19m 39s      fid50k 53.5233
network-snapshot-              time 19m 40s      fid50k 53.9403
network-snapshot-              time 19m 43s      fid50k 53.1017
network-snapshot-              time 19m 39s      fid50k 53.3370
network-snapshot-              time 19m 36s      fid50k 53.0706
network-snapshot-              time 19m 43s      fid50k 52.6289
network-snapshot-              time 19m 39s      fid50k 51.8526
network-snapshot-              time 19m 35s      fid50k 52.3760
network-snapshot-              time 19m 42s      fid50k 52.7780
network-snapshot-              time 19m 36s      fid50k 52.3064
network-snapshot-              time 19m 42s      fid50k 52.4976

もし学習が途中で止まってしまったら

僕のように何らかのエラーで学習が止まってしまった場合,results以下に保存されているnetwork-snapshot-*.pklのモデルを読み込むことでそこから学習をやり直すことができます.
必要な記述は以下の通り.

run_training.py
def run(...):
  ...
  train.resume_pkl = "./results/00000-stylegan2-tf_images-1gpu-config-f/network-snapshot-00640.pkl"
  train.resume_kimg = 640
  train.resume_time = 150960

train.resume_timeはlog.txtなどからモデルが保存されたときに出力されている計算時間をsecに直して入力すれば大丈夫です.

多分転移学習も同様に手法で,学習済モデルを指定してあげればできると思います.重みの固定化とか細かいことはモデルを読み込んでから手動で設定する必要があります.今回みたいなタスクで転移させる場合には全部再学習してしまってもいい気がしてますが...

log.txt
dnnlib: Running training.training_loop.training_loop() on localhost...
...
tick 40    kimg 640.1    lod 0.00  minibatch 32   time 1d 17h 56m   sec/tick 2588.1  sec/kimg 161.76  maintenance 1203.1 gpumem 5.1

僕が学習できたところまでの結果

640kimgでストレージ不足で落ちました(泣)
いい感じにFIDが下がっていたので悲しすぎます.
fakes000640.png

結構輪郭もはっきりしてきて,体の形だけじゃなくて顔みたいなものも再現され始めてますね.

早くこの先の学習結果が見たいのですが,学習し直しているので当分先になりそうです.学習が進み次第結果を追記します.

本家の実装からの変更点をすべて紹介しきれているか不安なので,コードを以下に挙げておきました.使い方もGitHubに書いてあるのでどうぞ.

https://github.com/Takuya-Shuto-engineer/PokemonGAN

参考文献

Takuya-Shuto-engineer
高専→大学→大学院のエンジニア志望 情報通信(物理層)の研究を経て、画像認識と画像生成に関する技術について研究をしています. 顔学・ゲームのデータ分析に関心を持っています.エンジニアを目指して修行中.
techtrain
プロのエンジニアを目指すU30(30歳以下)の方に現役エンジニアにメンタリングもらえるコミュニティです。
https://techbowl.co.jp/techtrain/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away