今回やること
何番煎じだという感じですが,ポケモンの自動生成モデルの学習をしようと思います.調べてみると最新のモデルのStyleGAN2ベースですでに試している方がいます.しかし実際のコードやデータセット,実装までは公開されていなかったのでこのへんを記事にまとめておきたいなと思っています.
StyleGAN2とは
StyleGAN2はNVIDIAが発表した画像生成モデルです.スタイル変換を用いることが特徴的な生成モデルで,現状では複数のタスクにおいてSOTAとなっている強力なモデルです.
データセット
何か良いデータセットはないものかと調べていると,事前にデータを15000件収集して公開してくれている方がいたのでこちらを使わせていただきました.
MonsterGAN
ほかにもKaggleのPokemon-Image-DatasetやOne-Shot-Pokemon-Imagesなども候補に挙がります.One-Shot-Pokemon-Imagesにはポケモンカードのデータセットが含まれていて,そちらが非常にデータ量が多いためポケモンカード生成タスクに応用されているようです.
※コチラより引用
先行事例
Michael Frieseという方が熱心にStyleGAN・StyleGAN2を用いたポケモン生成に挑戦していらっしゃるようです.
StyleGANでは既にかなりいい感じに結果を残されているようです.
※ コチラより引用
猫や馬などの学習済みモデルから転移学習を行うと,学習自体がうまくいく上に生成画像に転移前のドメインの雰囲気が反映されるらしい.
すごくエモい.
馬から転移
※ コチラから引用
猫から転移
※ コチラから引用
車から転移
※ コチラから引用
StyleGAN2を用いた生成結果は途中までしか公開されていませんでした.540kimgまで学習した結果です.学習途中なので品質は置いておいて,かなりバリエーションが多く感じます.
※ コチラから引用
自前で行った学習結果
まだ学習途中なのですが,StyleGAN2を使った学習の経過をお見せします.RTX1070で学習を行っているため,むちゃなモデルはメモリサイズの関係で不可能です.今回はデータセットをすべて64x64にリサイズして利用しました.
モデルの概要
G Params OutputShape WeightShape
--- --- --- ---
latents_in - (?, 512) -
labels_in - (?, 0) -
lod - () -
dlatent_avg - (512,) -
G_mapping/latents_in - (?, 512) -
G_mapping/labels_in - (?, 0) -
G_mapping/Normalize - (?, 512) -
G_mapping/Dense0 262656 (?, 512) (512, 512)
G_mapping/Dense1 262656 (?, 512) (512, 512)
G_mapping/Dense2 262656 (?, 512) (512, 512)
G_mapping/Dense3 262656 (?, 512) (512, 512)
G_mapping/Dense4 262656 (?, 512) (512, 512)
G_mapping/Dense5 262656 (?, 512) (512, 512)
G_mapping/Dense6 262656 (?, 512) (512, 512)
G_mapping/Dense7 262656 (?, 512) (512, 512)
G_mapping/Broadcast - (?, 10, 512) -
G_mapping/dlatents_out - (?, 10, 512) -
Truncation/Lerp - (?, 10, 512) -
G_synthesis/dlatents_in - (?, 10, 512) -
G_synthesis/4x4/Const 8192 (?, 512, 4, 4) (1, 512, 4, 4)
G_synthesis/4x4/Conv 2622465 (?, 512, 4, 4) (3, 3, 512, 512)
G_synthesis/4x4/ToRGB 264195 (?, 3, 4, 4) (1, 1, 512, 3)
G_synthesis/8x8/Conv0_up 2622465 (?, 512, 8, 8) (3, 3, 512, 512)
G_synthesis/8x8/Conv1 2622465 (?, 512, 8, 8) (3, 3, 512, 512)
G_synthesis/8x8/Upsample - (?, 3, 8, 8) -
G_synthesis/8x8/ToRGB 264195 (?, 3, 8, 8) (1, 1, 512, 3)
G_synthesis/16x16/Conv0_up 2622465 (?, 512, 16, 16) (3, 3, 512, 512)
G_synthesis/16x16/Conv1 2622465 (?, 512, 16, 16) (3, 3, 512, 512)
G_synthesis/16x16/Upsample - (?, 3, 16, 16) -
G_synthesis/16x16/ToRGB 264195 (?, 3, 16, 16) (1, 1, 512, 3)
G_synthesis/32x32/Conv0_up 2622465 (?, 512, 32, 32) (3, 3, 512, 512)
G_synthesis/32x32/Conv1 2622465 (?, 512, 32, 32) (3, 3, 512, 512)
G_synthesis/32x32/Upsample - (?, 3, 32, 32) -
G_synthesis/32x32/ToRGB 264195 (?, 3, 32, 32) (1, 1, 512, 3)
G_synthesis/64x64/Conv0_up 2622465 (?, 512, 64, 64) (3, 3, 512, 512)
G_synthesis/64x64/Conv1 2622465 (?, 512, 64, 64) (3, 3, 512, 512)
G_synthesis/64x64/Upsample - (?, 3, 64, 64) -
G_synthesis/64x64/ToRGB 264195 (?, 3, 64, 64) (1, 1, 512, 3)
G_synthesis/images_out - (?, 3, 64, 64) -
G_synthesis/noise0 - (1, 1, 4, 4) -
G_synthesis/noise1 - (1, 1, 8, 8) -
G_synthesis/noise2 - (1, 1, 8, 8) -
G_synthesis/noise3 - (1, 1, 16, 16) -
G_synthesis/noise4 - (1, 1, 16, 16) -
G_synthesis/noise5 - (1, 1, 32, 32) -
G_synthesis/noise6 - (1, 1, 32, 32) -
G_synthesis/noise7 - (1, 1, 64, 64) -
G_synthesis/noise8 - (1, 1, 64, 64) -
images_out - (?, 3, 64, 64) -
--- --- --- ---
Total 27032600
D Params OutputShape WeightShape
--- --- --- ---
images_in - (?, 3, 64, 64) -
labels_in - (?, 0) -
64x64/FromRGB 2048 (?, 512, 64, 64) (1, 1, 3, 512)
64x64/Conv0 2359808 (?, 512, 64, 64) (3, 3, 512, 512)
64x64/Conv1_down 2359808 (?, 512, 32, 32) (3, 3, 512, 512)
64x64/Skip 262144 (?, 512, 32, 32) (1, 1, 512, 512)
32x32/Conv0 2359808 (?, 512, 32, 32) (3, 3, 512, 512)
32x32/Conv1_down 2359808 (?, 512, 16, 16) (3, 3, 512, 512)
32x32/Skip 262144 (?, 512, 16, 16) (1, 1, 512, 512)
16x16/Conv0 2359808 (?, 512, 16, 16) (3, 3, 512, 512)
16x16/Conv1_down 2359808 (?, 512, 8, 8) (3, 3, 512, 512)
16x16/Skip 262144 (?, 512, 8, 8) (1, 1, 512, 512)
8x8/Conv0 2359808 (?, 512, 8, 8) (3, 3, 512, 512)
8x8/Conv1_down 2359808 (?, 512, 4, 4) (3, 3, 512, 512)
8x8/Skip 262144 (?, 512, 4, 4) (1, 1, 512, 512)
4x4/MinibatchStddev - (?, 513, 4, 4) -
4x4/Conv 2364416 (?, 512, 4, 4) (3, 3, 513, 512)
4x4/Dense0 4194816 (?, 512) (8192, 512)
Output 513 (?, 1) (512, 1)
scores_out - (?, 1) -
--- --- --- ---
Total 26488833
snapshotのグリッドの設定の変更を忘れていてめちゃくちゃ画像みにくくすみません...学習が大変なのでやり直しはまだです.
生成結果(288kimg: 19時間)
徐々に輪郭が形成されてきて,人型や動物型にも見えるポケモンの概形できてきました.
学習がしっかり進んでいるかどうかFrechet Inception Distance(FID)を出力して監視しています.今のところ順調に進んでいます.公式のページにはFIDは一桁だと載っているのですが,さすがにそこまで学習するのにコストが高すぎるので自分の目視で生成画像がいい感じになったら学習を止めるつもりです.
network-snapshot- time 19m 34s fid50k 278.0748
network-snapshot- time 19m 34s fid50k 382.7474
network-snapshot- time 19m 34s fid50k 338.3625
network-snapshot- time 19m 24s fid50k 378.2344
network-snapshot- time 19m 33s fid50k 306.3552
network-snapshot- time 19m 33s fid50k 173.8370
network-snapshot- time 19m 30s fid50k 112.3612
network-snapshot- time 19m 31s fid50k 99.9480
network-snapshot- time 19m 35s fid50k 90.2591
network-snapshot- time 19m 38s fid50k 75.5776
network-snapshot- time 19m 39s fid50k 67.8876
network-snapshot- time 19m 39s fid50k 66.0221
network-snapshot- time 19m 46s fid50k 63.2856
network-snapshot- time 19m 40s fid50k 64.6719
network-snapshot- time 19m 31s fid50k 64.2135
network-snapshot- time 19m 39s fid50k 63.6304
network-snapshot- time 19m 42s fid50k 60.5562
network-snapshot- time 19m 36s fid50k 59.4038
network-snapshot- time 19m 36s fid50k 57.2236
まとめ
今回はデータセットと先行事例を調査して,実際に学習を進めてみたところまでを書いてみました.
自分の環境で動いたコードやもっと学習が進んだ後の生成結果などは後日の後編の記事としてまとめます.
後編→ StyleGAN2で未知のポケモンを生み出す[後編]
追記: 学習過程の報告
実は,ここまで学習が進んだところでCドライブで実行していたためデータ容量が足りなくなってしまいモデルが空保存されて消えてしまいました...
反省して読み出し速度を犠牲にして増設したHDD上に移して再実行します.辛すぎる...二日間の学習結果が....
何も考えてなかった僕が悪いんですが,共感してくれる人はなぐさめて...
モデルの保存は容量を気にして最新のものだけを毎回上書き保存していたのですが,それが裏目に出ました.HDDにはこまめにモデルを保存していこうと思います.