昨今 midjourney
、 Stable Diffusion
といったAIによる画像生成技術の話題が盛り上がっており、画像生成技術の急激な発展には驚かされます。
本稿では、GANモデルのブレイクスルーである StyleGAN2 の実装コードを見ることで、AIによる画像生成技術の理解をさらに深めることを目指し、GitHubに公開されている StyleGAN2 の実装コードの解説を行います。
今回は 画像生成時のコード を解説します。
StyleGAN2のアーキテクチャの詳細については以下の記事に記載されていますので、一緒に見ると理解が進むかと思います。
学習時のコードの解説はこちらになります。
全てのコードを細かく説明するのではなく概ねの動作がわかるように解説してしています。
記載内容等に間違えがあれば随時ブラッシュアップしますのでご指摘をお願いします。
StyleGAN2のライブララリィ構成
画像生成に関連するライブラリィ、モジュールの構成と概要は以下の通りです。
説明 | ||||
---|---|---|---|---|
画像生成バッチメイン | run_generator.py | |||
DNN用パッケージ | $dnnlib$ | |||
実行用パッケージ | $submission$ | |||
内部用パッケージ | $internal$ | |||
submitのラッパモジュール | local.py | |||
学習や生成のループを管理するためのヘルパーモジュール | run_context.py | |||
実行モジュール | submit.py | |||
Tensorflow用パッケージ | $tflib$ | |||
画像や「活性化関数+バイアスを融合」させたオペレーションパッケージ | $ops$ | |||
活性化関数+バイアスを融合させたオペレーション(CUDA) | fused_bias_act.cu | |||
活性化関数+バイアスを融合させたオペレーション | fused_bias_act.py | |||
各種画像処理(CUDA) | upfirdn_2d.cu | |||
各種画像処理モジュール | upfirdn_2d.py | |||
Tensorboardのヘルパーモジュール | autosummary.py | |||
Tensorflowのカスタムオペレーションモジュール | custom_ops.py | |||
Tensorflowのネットワークラッパーモジュール | network.py | |||
オプティマイザモジュール | optimizer.py | |||
各種Tensorflowユーティリティモジュール | tfutil.py | |||
ユーティリティモジュール | util.py | |||
各種メトリックスパッケージ(学習には直接は関係ないのでモジュール構成は省略) | $metrics$ |
実装コードの解説
画像生成時のコードの流れを以下に順を追って説明します(Google Colab Proで実行して確認しています)。
画像生成の起動
まずは画像生成バッチ run_generator.py を起動します。 下記は、公式ホームページで顔画像を生成する際の起動コマンドサンプルです。
python run_generator.py generate-images --network=gdrive:networks/stylegan2-ffhq-config-f.pkl --seeds=6600-6625 --truncation-psi=0.5
ここで、 --network
で指定するのは学習済みのモデルになります。
実際には、 画像生成の開始時にpretrained_network.py
の 64行目load_networks()
が呼び出され 58行目の gdrive_urls.get()
で事前に定義されたディクショナリで該当するモデルが選択され(今回の場合は32行目)、該当モデルがネットワーク経由でダウンロード(76行目)されます。
・・
17 gdrive_urls = {
'gdrive:networks/stylegan2-car-config-a.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-car-config-a.pkl',
・・
30 'gdrive:networks/stylegan2-ffhq-config-d.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-ffhq-config-d.pkl',
31 'gdrive:networks/stylegan2-ffhq-config-e.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-ffhq-config-e.pkl',
32 'gdrive:networks/stylegan2-ffhq-config-f.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-ffhq-config-f.pkl',
・・
52 'gdrive:networks/table2/stylegan2-ffhq-config-e-Gskip-Dskip.pkl': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gskip-Dskip.pkl',
53 }
・・
57 def get_path_or_url(path_or_gdrive_path):
58 return gdrive_urls.get(path_or_gdrive_path, path_or_gdrive_path)
・・
64 def load_networks(path_or_gdrive_path):
65 path_or_url = get_path_or_url(path_or_gdrive_path)
66 if path_or_url in _cached_networks:
67 return _cached_networks[path_or_url]
68
69 if dnnlib.util.is_url(path_or_url):
70 stream = dnnlib.util.open_url(path_or_url, cache_dir='.stylegan2-cache')
71 else:
72 stream = open(path_or_url, 'rb')
73
74 tflib.init_tf()
75 with stream:
76 G, D, Gs = pickle.load(stream, encoding='latin1')
77 _cached_networks[path_or_url] = G, D, Gs
78 return G, D, Gs
run_generator.py
が起動されるとまず同モジュールの main()
メソッドが呼ばれます。
main()
メソッドでは起動パラメータのチェックを行い、dnnlib.submit_run()
メソッド(163行目)を呼び出します。
119 def main():
120 parser = argparse.ArgumentParser(
121 description='''StyleGAN2 generator.
122
123 Run 'python %(prog)s <subcommand> --help' for subcommand help.''',
124 epilog=_examples,
125 formatter_class=argparse.RawDescriptionHelpFormatter
126 )
・・
159 func_name_map = {
160 'generate-images': 'run_generator.generate_images',
161 'style-mixing-example': 'run_generator.style_mixing_example'
162 }
163 dnnlib.submit_run(sc, func_name_map[subcmd], **kwargs)
submit_run()
では、画像生成の準備を行い、 farm.submit()
メソッドを呼び出します(学習時と共通処理)。 farm.submit()
から dnnlib.submission.submit
モジュールの run_wrapper()
メソッドが呼ばれます。
310 def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None:
・・
343 return farm.submit(submit_config, host_run_dir)
run_wrapper()
の274行目の run_func_obj()
メソッドでは、run_func_name で指定された名前から util.get_obj_by_name() メソッドを使用して動的に呼び出しメソッドを生成しますが、これは、run_generator
モジュールの160行目の generate-images
に指定された run_generator
モジュールの generate_images()
メソッドが呼び出されることになります。
この generate_images()
が画像生成全体をコントロールするメソッドになります。
256 def run_wrapper(submit_config: SubmitConfig) -> None:
・・
274 run_func_obj = util.get_obj_by_name(submit_config.run_func_name)
・・
277 if 'submit_config' in sig.parameters:
278 run_func_obj(submit_config=submit_config, **submit_config.run_func_kwargs)
279 else:
run_func_obj(**submit_config.run_func_kwargs)
・・
310 def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None:
・・
343 return farm.submit(submit_config, host_run_dir)
# local.py経由でrun_wrapper()を呼び出す
256 def run_wrapper(submit_config: SubmitConfig) -> None:
"""Wrap the actual run function call for handling logging, exceptions, typing, etc."""
・・
270 try:
print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name))
・・
277 if 'submit_config' in sig.parameters:
278 run_func_obj(submit_config=submit_config, **submit_config.run_func_kwargs)
279 else:
280 run_func_obj(**submit_config.run_func_kwargs)
・・
282 print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time)))
画像の生成
generate_images()
では、先程の retrained_networks
のモジュールの load_networks()
メソッドで学習済みのモデルがロードされます(21行目)。ここで、_G
はGenerator(Mapping +
Synthesisモデル)、_D
はDiscriminatorモデル、Gs
は、Generatorの各inerationにおける値の指数移動平均を撮ったもので _G
より安定したモデルとして画像生成に使われています。
✳︎_G
と_D
モデルは画像生成では使用しません。
キーワードパラメータの編集を行った後、起動時に指定されたseed値分ループします(30行目)。
そしてMappingモデル入力する潜在変数(z)およびSynthesisモデルに入力するノイズの生成を行い(32〜34行目)、Generatorを起動します(35行目)。
Generatorの起動結果として生成されたFake画像の配列が返却されますので、起動時のパラメータ --result-dir
で指定されたresultフォルダ(デフォルトは results
になります)に、順次シード名を付与してRGBフォーマットで格納していきます(36行目)。
19 def generate_images(network_pkl, seeds, truncation_psi):
20 print('Loading networks from "%s"...' % network_pkl)
21 _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
22 noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]
23
24 Gs_kwargs = dnnlib.EasyDict()
25 Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
26 Gs_kwargs.randomize_noise = False
27 if truncation_psi is not None:
28 Gs_kwargs.truncation_psi = truncation_psi
29
30 for seed_idx, seed in enumerate(seeds):
31 print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
32 rnd = np.random.RandomState(seed)
33 z = rnd.randn(1, *Gs.input_shape[1:]) # [minibatch, component]
34 tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars}) # [height, width]
35 images = Gs.run(z, None, **Gs_kwargs) # [minibatch, height, width, channel]
36 PIL.Image.fromarray(images[0], 'RGB').save(dnnlib.make_run_dir_path('seed%04d.png' % seed))
おわりに
以上が、StyleGAN2の画像生成コードの概要になります。