LoginSignup
41
23

More than 1 year has passed since last update.

画像生成モデルをマスタしよう! StyleGAN2のコードを解説(画像生成編)

Last updated at Posted at 2022-12-04

昨今 midjourneyStable DiffusionといったAIによる画像生成技術の話題が盛り上がっており、画像生成技術の急激な発展には驚かされます。

本稿では、GANモデルのブレイクスルーである StyleGAN2 の実装コードを見ることで、AIによる画像生成技術の理解をさらに深めることを目指し、GitHubに公開されている StyleGAN2 の実装コードの解説を行います。

今回は 画像生成時のコード を解説します。
 
StyleGAN2のアーキテクチャの詳細については以下の記事に記載されていますので、一緒に見ると理解が進むかと思います。

学習時のコードの解説はこちらになります。

全てのコードを細かく説明するのではなく概ねの動作がわかるように解説してしています。
記載内容等に間違えがあれば随時ブラッシュアップしますのでご指摘をお願いします。

StyleGAN2のライブララリィ構成

  画像生成に関連するライブラリィ、モジュールの構成と概要は以下の通りです。

説明                          
画像生成バッチメイン run_generator.py
DNN用パッケージ $dnnlib$
実行用パッケージ $submission$
内部用パッケージ $internal$
submitのラッパモジュール local.py
学習や生成のループを管理するためのヘルパーモジュール run_context.py
実行モジュール submit.py
Tensorflow用パッケージ $tflib$
画像や「活性化関数+バイアスを融合」させたオペレーションパッケージ $ops$
活性化関数+バイアスを融合させたオペレーション(CUDA) fused_bias_act.cu
活性化関数+バイアスを融合させたオペレーション fused_bias_act.py
各種画像処理(CUDA) upfirdn_2d.cu
各種画像処理モジュール upfirdn_2d.py
Tensorboardのヘルパーモジュール autosummary.py
Tensorflowのカスタムオペレーションモジュール custom_ops.py
Tensorflowのネットワークラッパーモジュール network.py
オプティマイザモジュール optimizer.py
各種Tensorflowユーティリティモジュール tfutil.py
ユーティリティモジュール util.py
各種メトリックスパッケージ(学習には直接は関係ないのでモジュール構成は省略) $metrics$

実装コードの解説

画像生成時のコードの流れを以下に順を追って説明します(Google Colab Proで実行して確認しています)。

画像生成の起動

まずは画像生成バッチ run_generator.py を起動します。 下記は、公式ホームページで顔画像を生成する際の起動コマンドサンプルです。 

python run_generator.py generate-images --network=gdrive:networks/stylegan2-ffhq-config-f.pkl --seeds=6600-6625 --truncation-psi=0.5

ここで、 --network で指定するのは学習済みのモデルになります。
実際には、 画像生成の開始時にpretrained_network.pyの 64行目load_networks() が呼び出され 58行目の gdrive_urls.get() で事前に定義されたディクショナリで該当するモデルが選択され(今回の場合は32行目)、該当モデルがネットワーク経由でダウンロード(76行目)されます。

pretrained_network.py
     ・・
 17 gdrive_urls = {
    'gdrive:networks/stylegan2-car-config-a.pkl':                           'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-car-config-a.pkl',
    ・・   
 30     'gdrive:networks/stylegan2-ffhq-config-d.pkl':                          'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-ffhq-config-d.pkl',
 31     'gdrive:networks/stylegan2-ffhq-config-e.pkl':                          'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-ffhq-config-e.pkl',
 32     'gdrive:networks/stylegan2-ffhq-config-f.pkl':                          'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-ffhq-config-f.pkl',
    ・・
 52     'gdrive:networks/table2/stylegan2-ffhq-config-e-Gskip-Dskip.pkl':       'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gskip-Dskip.pkl',
 53 }
    ・・
 57 def get_path_or_url(path_or_gdrive_path):
 58     return gdrive_urls.get(path_or_gdrive_path, path_or_gdrive_path)
    ・・
 64 def load_networks(path_or_gdrive_path):
 65     path_or_url = get_path_or_url(path_or_gdrive_path)
 66     if path_or_url in _cached_networks:
 67         return _cached_networks[path_or_url]
 68
 69     if dnnlib.util.is_url(path_or_url):
 70         stream = dnnlib.util.open_url(path_or_url, cache_dir='.stylegan2-cache')
 71     else:
 72         stream = open(path_or_url, 'rb')
 73
 74     tflib.init_tf()
 75     with stream:
 76         G, D, Gs = pickle.load(stream, encoding='latin1')
 77     _cached_networks[path_or_url] = G, D, Gs
 78     return G, D, Gs

run_generator.py が起動されるとまず同モジュールの main() メソッドが呼ばれます。
main() メソッドでは起動パラメータのチェックを行い、dnnlib.submit_run() メソッド(163行目)を呼び出します。

run_generator.py
119 def main():
120     parser = argparse.ArgumentParser(
121         description='''StyleGAN2 generator.
122
123 Run 'python %(prog)s <subcommand> --help' for subcommand help.''',
124         epilog=_examples,
125         formatter_class=argparse.RawDescriptionHelpFormatter
126     )
     ・・
159     func_name_map = {
160         'generate-images': 'run_generator.generate_images',
161         'style-mixing-example': 'run_generator.style_mixing_example'
162     }
163     dnnlib.submit_run(sc, func_name_map[subcmd], **kwargs)

submit_run() では、画像生成の準備を行い、 farm.submit()メソッドを呼び出します(学習時と共通処理)。 farm.submit()から dnnlib.submission.submit モジュールの run_wrapper()メソッドが呼ばれます。

dnnlib.submission.submit.py
310 def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None:
・・

343 return farm.submit(submit_config, host_run_dir)

run_wrapper() の274行目の run_func_obj() メソッドでは、run_func_name で指定された名前から util.get_obj_by_name() メソッドを使用して動的に呼び出しメソッドを生成しますが、これは、run_generator モジュールの160行目の generate-imagesに指定された run_generator モジュールの generate_images() メソッドが呼び出されることになります。

この generate_images() が画像生成全体をコントロールするメソッドになります。

dnnlib.submission.submit.py
256 def run_wrapper(submit_config: SubmitConfig) -> None:
    ・・
274        run_func_obj = util.get_obj_by_name(submit_config.run_func_name)
    ・・
277        if 'submit_config' in sig.parameters:
278            run_func_obj(submit_config=submit_config, **submit_config.run_func_kwargs)
279        else:
               run_func_obj(**submit_config.run_func_kwargs)
    ・・
310 def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None:
    ・・
343 return farm.submit(submit_config, host_run_dir)
     # local.py経由でrun_wrapper()を呼び出す

256 def run_wrapper(submit_config: SubmitConfig) -> None:
    """Wrap the actual run function call for handling logging, exceptions, typing, etc."""
     ・・
270 try:
         print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name))
    ・・
277     if 'submit_config' in sig.parameters:
278         run_func_obj(submit_config=submit_config, **submit_config.run_func_kwargs)
279     else:
280         run_func_obj(**submit_config.run_func_kwargs)
    ・・
282     print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time)))

画像の生成

generate_images() では、先程の retrained_networks のモジュールの load_networks() メソッドで学習済みのモデルがロードされます(21行目)。ここで、_GはGenerator(Mapping +
Synthesisモデル)、_DはDiscriminatorモデル、Gs は、Generatorの各inerationにおける値の指数移動平均を撮ったもので _Gより安定したモデルとして画像生成に使われています。
✳︎_G_Dモデルは画像生成では使用しません。

キーワードパラメータの編集を行った後、起動時に指定されたseed値分ループします(30行目)。
そしてMappingモデル入力する潜在変数(z)およびSynthesisモデルに入力するノイズの生成を行い(32〜34行目)、Generatorを起動します(35行目)。
Generatorの起動結果として生成されたFake画像の配列が返却されますので、起動時のパラメータ --result-dir で指定されたresultフォルダ(デフォルトは resultsになります)に、順次シード名を付与してRGBフォーマットで格納していきます(36行目)。

run_generator.py
 19 def generate_images(network_pkl, seeds, truncation_psi):
 20     print('Loading networks from "%s"...' % network_pkl)
 21     _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
 22     noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]
 23
 24     Gs_kwargs = dnnlib.EasyDict()
 25     Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
 26     Gs_kwargs.randomize_noise = False
 27     if truncation_psi is not None:
 28         Gs_kwargs.truncation_psi = truncation_psi
 29
 30     for seed_idx, seed in enumerate(seeds):
 31         print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
 32         rnd = np.random.RandomState(seed)
 33         z = rnd.randn(1, *Gs.input_shape[1:]) # [minibatch, component]
 34         tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars}) # [height, width]
 35         images = Gs.run(z, None, **Gs_kwargs) # [minibatch, height, width, channel]
 36         PIL.Image.fromarray(images[0], 'RGB').save(dnnlib.make_run_dir_path('seed%04d.png' % seed))

おわりに

以上が、StyleGAN2の画像生成コードの概要になります。

41
23
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
41
23