LoginSignup
36
21

More than 1 year has passed since last update.

画像生成モデルをマスタしよう! StyleGAN2のコードを解説(学習編)

Last updated at Posted at 2022-12-03

昨今 midjourneyStable DiffusionといったAIによる画像生成技術の話題が盛り上がっており、画像生成技術の急激な発展には驚かされます。

本稿では、GANモデルのブレイクスルーである StyleGAN2 の実装コードを見ることで、AIによる画像生成技術の理解をさらに深めることを目指し、GitHubに公開されているStyleGAN2の実装コードの解説を行います。 

今回は学習時のコードを解説します。

なお、StyleGAN2のアーキテクチャの詳細については以下の記事に記載されていますので、一緒に見ると理解が進むかと思います。

全てのコードを細かく説明するのではなく概ねの動作がわかるように解説してしています。
また、記載内容等に間違えがあれば随時ブラッシュアップしますのでご指摘をお願いします。

StyleGAN2のライブララリィ構成

  学習に関連するライブラリィ、モジュールの構成と概要は以下の通りです。

説明                     
学習バッチメイン run_training.py
学習用TFRecord生成バッチ dataset_tool.py
学習用パッケージ $training$
TFRecord読込みモジュール dataset.py
損失関数モジュール loss.py
学習用共通モジュール misc.py
StyleGAN2ネットワーク定義モジュール networks_stylegan2.py
学習ループ処理(学習制御)モジュール training_loop.py
DNN用パッケージ $dnnlib$
実行用パッケージ $submission$
内部用パッケージ $internal$
submitのラッパモジュール local.py
学習や生成のループを管理するためのヘルパーモジュール run_context.py
実行モジュール submit.py
Tensorflow用パッケージ $tflib$
画像や「活性化関数+バイアス融合」のオペレーションパッケージ $ops$
活性化関数+バイアスを融合させたオペレーション(CUDA) fused_bias_act.cu
活性化関数+バイアスを融合させたオペレーション fused_bias_act.py
各種画像処理(CUDA) upfirdn_2d.cu
各種画像処理モジュール upfirdn_2d.py
Tensorboardのヘルパーモジュール autosummary.py
Tensorflowのカスタムオペレーションモジュール custom_ops.py
Tensorflowのネットワークラッパーモジュール network.py
オプティマイザモジュール optimizer.py
各種Tensorflowユーティリティモジュール tfutil.py
ユーティリティモジュール util.py
各種メトリックスパッケージ(学習には直接は関係ないのでモジュール構成は省略) $metrics$

実装コードの解説

学習時のコードの流れを以下に順を追って説明します(Google Colab Proで実行して確認しています)。

学習の起動

まずは学習用バッチ run_training.py を起動します。 下記は、公式ホームページで顔画像を学習する際の起動コマンドのサンプルです。 今回は、Googleコラボで検証しましたので割り当て可能なGPU数は1ですので実際には「--num-gpus=1」で起動しています。

python run_training.py --num-gpus=8 --data-dir=~/datasets --config=config-f \
  --dataset=ffhq --mirror-augment=true

run_training.py が起動されるとまず同モジュール内の main() メソッドが呼ばれ(156行目)、起動パラメータの構文解析が行われた後に同モジュール内の run() メソッドが呼ばれます(163行目)。

run_traning.py
156 def main():
    parser = argparse.ArgumentParser(
        description='Train StyleGAN2.',
        epilog=_examples,
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
     ・・
163     run(**vars(args))

run()メソッドでは、EasyDictクラス(ディクショナリクラスを継承したgetter/setter)を使って様々なパラメータ管理用のデータクラスを生成し、dnnlib.submit_run() メソッドを呼び出します(120行目)。

run_traning.py
 36 def run(dataset, data_dir, result_dir, config_id, num_gpus, total_kimg, gamma, mirror_augment, metrics):
 37   train     = EasyDict(run_func_name='training.training_loop.training_loop') # Options for training loop.
 38   G         = EasyDict(func_name='training.networks_stylegan2.G_main')       # Options for generator network.
 39   D         = EasyDict(func_name='training.networks_stylegan2.D_stylegan2')  # Options for discriminator network.
 40   G_opt     = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)                  # Options for generator optimizer.
 41   D_opt     = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)                  # Options for discriminator optimizer.
 42   G_loss    = EasyDict(func_name='training.loss.G_logistic_ns_pathreg')      # Options for generator loss.
 43   D_loss    = EasyDict(func_name='training.loss.D_logistic_r1')              # Options for discriminator loss.
・・
120 dnnlib.submit_run(**kwargs)

submit_run() では学習実行に向けた準備を行い、その後 farm.submit()メソッドを呼び出します。
farm.submit() では dnnlib.submission.submitモジュールの run_wrapper()を呼びます(343行目)。

dnnlib.submission.submit.py
310 def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None:
・・

343 return farm.submit(submit_config, host_run_dir)

run_wrapper()の278行目の run_func_obj() メソッドでは、 run_func_name で指定された名前から util.get_obj_by_name() メソッドを使用して動的に呼び出しメソッドを生成しますが(274行目)、これは、 run_training モジュールの run()メソッドの37行目 train = EasyDict(run_func_name='training.training_loop.training_loop')
で指定された training_loop モジュールの training_loop()メソッドが呼び出されることになります。

この training_loop() が学習全体をコントロールするメソッドになります。

dnnlib.submission.submit.py
256 def run_wrapper(submit_config: SubmitConfig) -> None:
・・
274        run_func_obj = util.get_obj_by_name(submit_config.run_func_name)
・・
277        if 'submit_config' in sig.parameters:
278            run_func_obj(submit_config=submit_config, **submit_config.run_func_kwargs)
279        else:
               run_func_obj(**submit_config.run_func_kwargs)
・・
 310 def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None:
    ・・
 343 return farm.submit(submit_config, host_run_dir)
     # local.py経由でrun_wrapper()を呼び出す

 256 def run_wrapper(submit_config: SubmitConfig) -> None:
    """Wrap the actual run function call for handling logging, exceptions, typing, etc."""
     ・・
 270 try:
         print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name))
         ・・
 277     if 'submit_config' in sig.parameters:
 278         run_func_obj(submit_config=submit_config, **submit_config.run_func_kwargs)
 279     else:
 280         run_func_obj(**submit_config.run_func_kwargs)
         ・・
 282     print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time)))

学習の準備

training_loop() メソッドでは、まず学習データ(TFRecord形式)のロードとネットワークモデルの生成を行います。ネットワークモデルがパラメータで指定されていた場合は、該当モデルをロードして追加学習をおこないます。(オリジナルのコードでは、このモデル名は training_loop() メソッドの引数 resume_pkl にハードコーディングで指定します)。

生成するネットワークモデルは G(Generator)、D(Discriminator)、Gs(Gの学習中の各iterationにおける値の指数移動平均を取ったもので画像生成時はこのモデルを使う)です。モデルの生成が完了したらGおよびDのモデル構造をプリント(159行目)します。

training_loop.py
131     resume_pkl              = None,     # Network pickle to resume training from, None = train from scratch.
・・
136     # Initialize dnnlib and TensorFlow.
137 tflib.init_tf(tf_config)
・・
140     # Load training set.
142     training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args)
・・
145     # Construct or load networks.
146     with tf.device('/gpu:0'):
147         if resume_pkl is None or resume_with_new_nets: # ← 新規学習の場合 
148             print('Constructing networks...')
149             G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
150             D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
151             Gs = G.clone('Gs')
152         if resume_pkl is not None:  # ← 追加学習の場合 
・・
159     G.print_layers(); D.print_layers()
・・

実態のネットワークモデルは、 training.networks_stylegan2 モジュールの定義に従って生成されますが、ここでは詳細な説明は割愛(定義のままであり冒頭に紹介しましたStyleGAN2を徹底解剖!の記事をみた方が理解できます)し、実際に生成されたネットワークモデルの印刷のみ紹介します。

G                             Params    OutputShape         WeightShape     
---                           ---       ---                 ---             
latents_in                    -         (?, 512)            -               
labels_in                     -         (?, 0)              -               
lod                           -         ()                  -               
dlatent_avg                   -         (512,)              -               
G_mapping/latents_in          -         (?, 512)            -               
G_mapping/labels_in           -         (?, 0)              -               
G_mapping/Normalize           -         (?, 512)            -               
G_mapping/Dense0              262656    (?, 512)            (512, 512)      
G_mapping/Dense1              262656    (?, 512)            (512, 512)      
G_mapping/Dense2              262656    (?, 512)            (512, 512)      
G_mapping/Dense3              262656    (?, 512)            (512, 512)      
G_mapping/Dense4              262656    (?, 512)            (512, 512)      
G_mapping/Dense5              262656    (?, 512)            (512, 512)      
G_mapping/Dense6              262656    (?, 512)            (512, 512)      
G_mapping/Dense7              262656    (?, 512)            (512, 512)      
G_mapping/Broadcast           -         (?, 12, 512)        -               
G_mapping/dlatents_out        -         (?, 12, 512)        -               
Truncation/Lerp               -         (?, 12, 512)        -               
G_synthesis/dlatents_in       -         (?, 12, 512)        -               
G_synthesis/4x4/Const         8192      (?, 512, 4, 4)      (1, 512, 4, 4)  
G_synthesis/4x4/Conv          2622465   (?, 512, 4, 4)      (3, 3, 512, 512)
G_synthesis/4x4/ToRGB         264195    (?, 3, 4, 4)        (1, 1, 512, 3)  
G_synthesis/8x8/Conv0_up      2622465   (?, 512, 8, 8)      (3, 3, 512, 512)
G_synthesis/8x8/Conv1         2622465   (?, 512, 8, 8)      (3, 3, 512, 512)
G_synthesis/8x8/Upsample      -         (?, 3, 8, 8)        -               
G_synthesis/8x8/ToRGB         264195    (?, 3, 8, 8)        (1, 1, 512, 3)  
G_synthesis/16x16/Conv0_up    2622465   (?, 512, 16, 16)    (3, 3, 512, 512)
G_synthesis/16x16/Conv1       2622465   (?, 512, 16, 16)    (3, 3, 512, 512)
G_synthesis/16x16/Upsample    -         (?, 3, 16, 16)      -               
G_synthesis/16x16/ToRGB       264195    (?, 3, 16, 16)      (1, 1, 512, 3)  
G_synthesis/32x32/Conv0_up    2622465   (?, 512, 32, 32)    (3, 3, 512, 512)
G_synthesis/32x32/Conv1       2622465   (?, 512, 32, 32)    (3, 3, 512, 512)
G_synthesis/32x32/Upsample    -         (?, 3, 32, 32)      -               
G_synthesis/32x32/ToRGB       264195    (?, 3, 32, 32)      (1, 1, 512, 3)  
G_synthesis/64x64/Conv0_up    2622465   (?, 512, 64, 64)    (3, 3, 512, 512)
G_synthesis/64x64/Conv1       2622465   (?, 512, 64, 64)    (3, 3, 512, 512)
G_synthesis/64x64/Upsample    -         (?, 3, 64, 64)      -               
G_synthesis/64x64/ToRGB       264195    (?, 3, 64, 64)      (1, 1, 512, 3)  
G_synthesis/128x128/Conv0_up  1442561   (?, 256, 128, 128)  (3, 3, 512, 256)
G_synthesis/128x128/Conv1     721409    (?, 256, 128, 128)  (3, 3, 256, 256)
G_synthesis/128x128/Upsample  -         (?, 3, 128, 128)    -               
G_synthesis/128x128/ToRGB     132099    (?, 3, 128, 128)    (1, 1, 256, 3)  
G_synthesis/images_out        -         (?, 3, 128, 128)    -               
G_synthesis/noise0            -         (1, 1, 4, 4)        -               
G_synthesis/noise1            -         (1, 1, 8, 8)        -               
G_synthesis/noise2            -         (1, 1, 8, 8)        -               
G_synthesis/noise3            -         (1, 1, 16, 16)      -               
G_synthesis/noise4            -         (1, 1, 16, 16)      -               
G_synthesis/noise5            -         (1, 1, 32, 32)      -               
G_synthesis/noise6            -         (1, 1, 32, 32)      -               
G_synthesis/noise7            -         (1, 1, 64, 64)      -               
G_synthesis/noise8            -         (1, 1, 64, 64)      -               
G_synthesis/noise9            -         (1, 1, 128, 128)    -               
G_synthesis/noise10           -         (1, 1, 128, 128)    -               
images_out                    -         (?, 3, 128, 128)    -               
---                           ---       ---                 ---             
Total                         29328669                                      


D                    Params    OutputShape         WeightShape     
---                  ---       ---                 ---             
images_in            -         (?, 3, 128, 128)    -               
labels_in            -         (?, 0)              -               
128x128/FromRGB      1024      (?, 256, 128, 128)  (1, 1, 3, 256)  
128x128/Conv0        590080    (?, 256, 128, 128)  (3, 3, 256, 256)
128x128/Conv1_down   1180160   (?, 512, 64, 64)    (3, 3, 256, 512)
128x128/Skip         131072    (?, 512, 64, 64)    (1, 1, 256, 512)
64x64/Conv0          2359808   (?, 512, 64, 64)    (3, 3, 512, 512)
64x64/Conv1_down     2359808   (?, 512, 32, 32)    (3, 3, 512, 512)
64x64/Skip           262144    (?, 512, 32, 32)    (1, 1, 512, 512)
32x32/Conv0          2359808   (?, 512, 32, 32)    (3, 3, 512, 512)
32x32/Conv1_down     2359808   (?, 512, 16, 16)    (3, 3, 512, 512)
32x32/Skip           262144    (?, 512, 16, 16)    (1, 1, 512, 512)
16x16/Conv0          2359808   (?, 512, 16, 16)    (3, 3, 512, 512)
16x16/Conv1_down     2359808   (?, 512, 8, 8)      (3, 3, 512, 512)
16x16/Skip           262144    (?, 512, 8, 8)      (1, 1, 512, 512)
8x8/Conv0            2359808   (?, 512, 8, 8)      (3, 3, 512, 512)
8x8/Conv1_down       2359808   (?, 512, 4, 4)      (3, 3, 512, 512)
8x8/Skip             262144    (?, 512, 4, 4)      (1, 1, 512, 512)
4x4/MinibatchStddev  -         (?, 513, 4, 4)      -               
4x4/Conv             2364416   (?, 512, 4, 4)      (3, 3, 513, 512)
4x4/Dense0           4194816   (?, 512)            (8192, 512)     
Output               513       (?, 1)              (512, 1)        
scores_out           -         (?, 1)              -               
---                  ---       ---                 ---             
Total                28389121                                      

本来は1024x1024の層までありますが、128x128の層までの学習モデルの例です

【Gsについて補足】
Gsは、本文中にも記載しましたが、Gの学習中の各iterationにおける値の指数移動平均です。

  • $ Gs \leftarrow \beta Gs + ( 1 - \beta) G $
  • $ \beta = 0.9978$ (学習設定により異なります。デフォルトでは、バッチサイズ32、1万サンプルで半減する設定と思われます)
  • Gモデル中のMappingモデル、Syntheisモデルのパラメータ両方を平均化します
  • GsはBackpropagationには関与しません
  • GsはGより安定した重みとして、学習後の画像生成に使用します
  • GsはマルチGPU設定に影響されません(Gの更新結果が全GPUに同期された後にGsの更新が動作します)
  • 教師あり学習で使われるStochastic Weight Averaginと等価の実装と思われます

149行目や150行目のネットワーク生成で使用される Networkモジュールは、Tensorflowでのネットワークモデルの抽象化クラスです。実際に生成されるクラスは、run_training モジュールの run() メソッドで指定した G = networks_stylegan2.G_main()D = networks_stylegan2.D_stylegan2() になります。 G_main() では、190行目で G_mapping()を呼び出してMappingモデルを、186行目でG_synthesis_stylegan2() を呼び出してSynthesisモデルをそれぞれ生成しています。

training.networks_stylegan2.py
151 def G_main(
・・
184 # Setup components.
185    if 'synthesis' not in components:
186        components.synthesis = tflib.Network('G_synthesis', func_name=globals()[synthesis_func], **kwargs)
187    num_layers = components.synthesis.input_shape[1]
188    dlatent_size = components.synthesis.input_shape[2]
189    if 'mapping' not in components:
190        components.mapping = tflib.Network('G_mapping', func_name=globals()[mapping_func], dlatent_broadcast=num_layers, **kwargs)
・・
251 def G_mapping(
・・
417 def G_synthesis_stylegan2(
・・
634 def D_stylegan2(

ネットワークモデルの生成が終わると、input情報の設定、オプティマイザの設定(176行目以降)を行います。

taining_loop.py
165   # Setup training inputs.
166    print('Building TensorFlow graph...')
167    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
168        lod_in               = tf.placeholder(tf.float32, name='lod_in', shape=[])
169        lrate_in             = tf.placeholder(tf.float32, name='lrate_in', shape=[])
170        minibatch_size_in    = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[])
171        minibatch_gpu_in     = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[])
172        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)
173        Gs_beta              = 0.5 ** tf.div(tf.cast(minibatch_size_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0
174
175    # Setup optimizers.
176    G_opt_args = dict(G_opt_args)
177    D_opt_args = dict(D_opt_args)
178    for args, reg_interval in [(G_opt_args, G_reg_interval), (D_opt_args, D_reg_interval)]:
179        args['minibatch_multiplier'] = minibatch_multiplier
180        args['learning_rate'] = lrate_in
181        if lazy_regularization:
182            mb_ratio = reg_interval / (reg_interval + 1)
183            args['learning_rate'] *= mb_ratio
184            if 'beta1' in args: args['beta1'] **= mb_ratio
185            if 'beta2' in args: args['beta2'] **= mb_ratio
186    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
187    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
188    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
189    D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)

学習時のオプティマイザは、tflib.Optimizerクラスで指定(28行目)した Adam ががデフォルトで使われます。

dnnlib.tflib.optimizer.py
class Optimizer:
・・
 26    def __init__(self,
 27       name:                   str             = "Train",                  # Name string that will appear in TensorFlow graph.
 28       tf_optimizer:           str             = "tf.train.AdamOptimizer", # Underlying optimizer class.
 29       learning_rate:          TfExpressionEx  = 0.001,                    # Learning rate. Can vary over time.
 30       minibatch_multiplier:   TfExpressionEx  = None,                     # Treat N consecutive minibatches as one by accumulating gradients.
・・

続いて各GPUごとにトレーニンググラフを作成します。
ここでは、学習データの取得と、損失関数(loss function)を登録します。
220行目の G_loss, G_reg = dnnlib.util.call_func_by_name() では、 traning_runモジュールの run() で指定された training.loss.G_logistic_ns_pathreg()メソッドが、 222行目のD_loss, D_reg = dnnlib.util.call_func_by_name() では同様に training.loss.D_logistic_r1() メソッドが呼び出されます。

training_loop.py
191    # Build training graph for each GPU.
192    data_fetch_ops = []
193    for gpu in range(num_gpus):
194        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

196            # Create GPU-specific shadow copies of G and D.
197            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
198            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

200            # Fetch training data via temporary variables.
201            with tf.name_scope('DataFetch'):
202                sched = training_schedule(cur_nimg=int(resume_kimg*1000), training_set=training_set, **sched_args)
203                reals_var = tf.Variable(name='reals', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape))
204                labels_var = tf.Variable(name='labels', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu, training_set.label_size]))
205                reals_write, labels_write = training_set.get_minibatch_tf()
206                reals_write, labels_write = process_reals(reals_write, labels_write, lod_in, mirror_augment, training_set.dynamic_range, drange_net)
207                reals_write = tf.concat([reals_write, reals_var[minibatch_gpu_in:]], axis=0)
208                labels_write = tf.concat([labels_write, labels_var[minibatch_gpu_in:]], axis=0)
209                data_fetch_ops += [tf.assign(reals_var, reals_write)]
210                data_fetch_ops += [tf.assign(labels_var, labels_write)]
211                reals_read = reals_var[:minibatch_gpu_in]
212                labels_read = labels_var[:minibatch_gpu_in]

214            # Evaluate loss functions.
215            lod_assign_ops = []
216            if 'lod' in G_gpu.vars: lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
217            if 'lod' in D_gpu.vars: lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
218            with tf.control_dependencies(lod_assign_ops):
219                with tf.name_scope('G_loss'):
220                    G_loss, G_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, **G_loss_args)
221                with tf.name_scope('D_loss'):
222                    D_loss, D_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, reals=reals_read, labels=labels_read, **D_loss_args)

224            # Register gradients.
225            if not lazy_regularization:
226                if G_reg is not None: G_loss += G_reg
227                if D_reg is not None: D_loss += D_reg
228            else:
229                if G_reg is not None: G_reg_opt.register_gradients(tf.reduce_mean(G_reg * G_reg_interval), G_gpu.trainables)
230                if D_reg is not None: D_reg_opt.register_gradients(tf.reduce_mean(D_reg * D_reg_interval), D_gpu.trainables)
231            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
232            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

続いてオペレーションの設定が行われます。
Generator、Discriminator、reg学習、regDiscriminatorの学習と損失計算のオペレーション設定が行われす。また、Gsのオペレーション設定が行われます。

training_loop.py
234 # Setup training ops.
235    data_fetch_op = tf.group(*data_fetch_ops)
236    G_train_op = G_opt.apply_updates()
237    D_train_op = D_opt.apply_updates()
238    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
239    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
240    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

【補足】regについて
論文中に記載されている「Lazy regularization」の実装になります。
正則化項の更新は計算コストとメモリ使用量を削減から更新頻度を下げる(16ミニバッチごとに1回の実行でも影響がないことが分かっているため)戦略をとっています。

学習の実行

起動パラーメタで指定した学習枚数(total_kimg✖️1000枚)まで、指定したバッチサイズ(minibatchRepeat デフォルト=4(1GPUの時))で学習を繰り返します。この時、先ほど生成したオペレーションを実行します(tflib.run(G_train_op, feed_dict) 等)。

training_loop.py
266    while cur_nimg < total_kimg * 1000:
267        if dnnlib.RunContext.get().should_stop(): break

269        # Choose training parameters and configure training ops.
270        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, **sched_args)
271        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
272        training_set.configure(sched.minibatch_gpu, sched.lod)
273        if reset_opt_for_new_lod:
274            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
275                G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state()
276        prev_lod = sched.lod
・・
278        # Run training ops.
279        feed_dict = {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_size_in: sched.minibatch_size, minibatch_gpu_in: sched.minibatch_gpu}
280        for _repeat in range(minibatch_repeats):
281            rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus)
282            run_G_reg = (lazy_regularization and running_mb_counter % G_reg_interval == 0)
283            run_D_reg = (lazy_regularization and running_mb_counter % D_reg_interval == 0)
284            cur_nimg += sched.minibatch_size
285            running_mb_counter += 1

287            # Fast path without gradient accumulation.
288            if len(rounds) == 1:
289                tflib.run([G_train_op, data_fetch_op], feed_dict)
290                if run_G_reg:
291                    tflib.run(G_reg_op, feed_dict)
292                tflib.run([D_train_op, Gs_update_op], feed_dict)
293                if run_D_reg:
294                    tflib.run(D_reg_op, feed_dict)

296            # Slow path with gradient accumulation.
297            else:
298                for _round in rounds:
299                    tflib.run(G_train_op, feed_dict)
300                if run_G_reg:
301                    for _round in rounds:
302                        tflib.run(G_reg_op, feed_dict)
303                tflib.run(Gs_update_op, feed_dict)
304                for _round in rounds:
305                    tflib.run(data_fetch_op, feed_dict)
306                    tflib.run(D_train_op, feed_dict)
307                if run_D_reg:
308                    for _round in rounds:
309                        tflib.run(D_reg_op, feed_dict)

ミニバッチの終了都度、学習の終了判定(312行目)、tick判定(313行目)を行います。tickに該当すれば中間レポート(autosummary)の生成(322〜332行目)、現在学習されたモデルでFake画像を生成・保存(336,337行目)、モデルのスナップショット(339,340行目)の作成、生成された画像の品質評価等(341行目、デフォルトではfid50k(Fréchet Inception Distance)で評価)を行います。

training_loop.py
311       # Perform maintenance tasks once per tick.
312        done = (cur_nimg >= total_kimg * 1000)
313        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
314            cur_tick += 1
315            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
316            tick_start_nimg = cur_nimg
317            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
318            total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time

320            # Report progress.
321            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f' % (
322                autosummary('Progress/tick', cur_tick),
323                autosummary('Progress/kimg', cur_nimg / 1000.0),
324                autosummary('Progress/lod', sched.lod),
325                autosummary('Progress/minibatch', sched.minibatch_size),
326                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
327                autosummary('Timing/sec_per_tick', tick_time),
328                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
329                autosummary('Timing/maintenance_sec', maintenance_time),
330                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
331            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
332            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

334            # Save snapshots.
335            if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done):
336                grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu)
337                misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
338            if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done):
339                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000))
340                misc.save_pkl((G, D, Gs), pkl)
341                metrics.run(pkl, run_dir=dnnlib.make_run_dir_path(), data_dir=dnnlib.convert_path(data_dir), num_gpus=num_gpus, tf_config=tf_config)

343            # Update summaries and RunContext.
344            metrics.update_autosummaries()
345            tflib.autosummary.save_summaries(summary_log, cur_nimg)
346            dnnlib.RunContext.get().update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
347            maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

学習の終了

学習が終了するとモデルを 'network-final.pklとして保存して終了します(350行目)。

training_loop.py
349    # Save final snapshot.
350    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('network-final.pkl'))

352    # All done.
353    summary_log.close()
354    training_set.close()

以上が、学習時のコードの流れの概要になります。

おわりに

細かい実装までは記載していませんが、おおよそのコードの実装はわかるのかと思います。
この学習済みモデルを用いて画像生成を行いますが、画像生成時のコードの説明は別の記事で行います。
 

36
21
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
36
21