昨今 midjourney
、 Stable Diffusion
といったAIによる画像生成技術の話題が盛り上がっており、画像生成技術の急激な発展には驚かされます。
本稿では、GANモデルのブレイクスルーである StyleGAN2
の実装コードを見ることで、AIによる画像生成技術の理解をさらに深めることを目指し、GitHubに公開されているStyleGAN2の実装コードの解説を行います。
今回は学習時のコードを解説します。
なお、StyleGAN2のアーキテクチャの詳細については以下の記事に記載されていますので、一緒に見ると理解が進むかと思います。
全てのコードを細かく説明するのではなく概ねの動作がわかるように解説してしています。
また、記載内容等に間違えがあれば随時ブラッシュアップしますのでご指摘をお願いします。
StyleGAN2のライブララリィ構成
学習に関連するライブラリィ、モジュールの構成と概要は以下の通りです。
説明 | ||||
---|---|---|---|---|
学習バッチメイン | run_training.py | |||
学習用TFRecord生成バッチ | dataset_tool.py | |||
学習用パッケージ | $training$ | |||
TFRecord読込みモジュール | dataset.py | |||
損失関数モジュール | loss.py | |||
学習用共通モジュール | misc.py | |||
StyleGAN2ネットワーク定義モジュール | networks_stylegan2.py | |||
学習ループ処理(学習制御)モジュール | training_loop.py | |||
DNN用パッケージ | $dnnlib$ | |||
実行用パッケージ | $submission$ | |||
内部用パッケージ | $internal$ | |||
submitのラッパモジュール | local.py | |||
学習や生成のループを管理するためのヘルパーモジュール | run_context.py | |||
実行モジュール | submit.py | |||
Tensorflow用パッケージ | $tflib$ | |||
画像や「活性化関数+バイアス融合」のオペレーションパッケージ | $ops$ | |||
活性化関数+バイアスを融合させたオペレーション(CUDA) | fused_bias_act.cu | |||
活性化関数+バイアスを融合させたオペレーション | fused_bias_act.py | |||
各種画像処理(CUDA) | upfirdn_2d.cu | |||
各種画像処理モジュール | upfirdn_2d.py | |||
Tensorboardのヘルパーモジュール | autosummary.py | | | ||
Tensorflowのカスタムオペレーションモジュール | custom_ops.py | |||
Tensorflowのネットワークラッパーモジュール | network.py | |||
オプティマイザモジュール | optimizer.py | |||
各種Tensorflowユーティリティモジュール | tfutil.py | |||
ユーティリティモジュール | util.py | |||
各種メトリックスパッケージ(学習には直接は関係ないのでモジュール構成は省略) | $metrics$ |
実装コードの解説
学習時のコードの流れを以下に順を追って説明します(Google Colab Proで実行して確認しています)。
学習の起動
まずは学習用バッチ run_training.py
を起動します。 下記は、公式ホームページで顔画像を学習する際の起動コマンドのサンプルです。 今回は、Googleコラボで検証しましたので割り当て可能なGPU数は1ですので実際には「--num-gpus=1」
で起動しています。
python run_training.py --num-gpus=8 --data-dir=~/datasets --config=config-f \
--dataset=ffhq --mirror-augment=true
run_training.py
が起動されるとまず同モジュール内の main()
メソッドが呼ばれ(156行目)、起動パラメータの構文解析が行われた後に同モジュール内の run()
メソッドが呼ばれます(163行目)。
156 def main():
parser = argparse.ArgumentParser(
description='Train StyleGAN2.',
epilog=_examples,
formatter_class=argparse.RawDescriptionHelpFormatter
)
・・
163 run(**vars(args))
run()メソッド
では、EasyDictクラス(ディクショナリクラスを継承したgetter/setter)を使って様々なパラメータ管理用のデータクラスを生成し、dnnlib.submit_run()
メソッドを呼び出します(120行目)。
36 def run(dataset, data_dir, result_dir, config_id, num_gpus, total_kimg, gamma, mirror_augment, metrics):
37 train = EasyDict(run_func_name='training.training_loop.training_loop') # Options for training loop.
38 G = EasyDict(func_name='training.networks_stylegan2.G_main') # Options for generator network.
39 D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2') # Options for discriminator network.
40 G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for generator optimizer.
41 D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) # Options for discriminator optimizer.
42 G_loss = EasyDict(func_name='training.loss.G_logistic_ns_pathreg') # Options for generator loss.
43 D_loss = EasyDict(func_name='training.loss.D_logistic_r1') # Options for discriminator loss.
・・
120 dnnlib.submit_run(**kwargs)
submit_run()
では学習実行に向けた準備を行い、その後 farm.submit()
メソッドを呼び出します。
farm.submit()
では dnnlib.submission.submit
モジュールの run_wrapper()
を呼びます(343行目)。
310 def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None:
・・
343 return farm.submit(submit_config, host_run_dir)
run_wrapper()
の278行目の run_func_obj()
メソッドでは、 run_func_name
で指定された名前から util.get_obj_by_name()
メソッドを使用して動的に呼び出しメソッドを生成しますが(274行目)、これは、 run_training
モジュールの run()
メソッドの37行目 train = EasyDict(run_func_name='training.training_loop.training_loop')
で指定された training_loop
モジュールの training_loop()
メソッドが呼び出されることになります。
この training_loop()
が学習全体をコントロールするメソッドになります。
256 def run_wrapper(submit_config: SubmitConfig) -> None:
・・
274 run_func_obj = util.get_obj_by_name(submit_config.run_func_name)
・・
277 if 'submit_config' in sig.parameters:
278 run_func_obj(submit_config=submit_config, **submit_config.run_func_kwargs)
279 else:
run_func_obj(**submit_config.run_func_kwargs)
・・
310 def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None:
・・
343 return farm.submit(submit_config, host_run_dir)
# local.py経由でrun_wrapper()を呼び出す
256 def run_wrapper(submit_config: SubmitConfig) -> None:
"""Wrap the actual run function call for handling logging, exceptions, typing, etc."""
・・
270 try:
print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name))
・・
277 if 'submit_config' in sig.parameters:
278 run_func_obj(submit_config=submit_config, **submit_config.run_func_kwargs)
279 else:
280 run_func_obj(**submit_config.run_func_kwargs)
・・
282 print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time)))
学習の準備
training_loop()
メソッドでは、まず学習データ(TFRecord形式)のロードとネットワークモデルの生成を行います。ネットワークモデルがパラメータで指定されていた場合は、該当モデルをロードして追加学習をおこないます。(オリジナルのコードでは、このモデル名は training_loop()
メソッドの引数 resume_pkl
にハードコーディングで指定します)。
生成するネットワークモデルは G(Generator)、D(Discriminator)、Gs(Gの学習中の各iterationにおける値の指数移動平均を取ったもので画像生成時はこのモデルを使う)です。モデルの生成が完了したらGおよびDのモデル構造をプリント(159行目)します。
131 resume_pkl = None, # Network pickle to resume training from, None = train from scratch.
・・
136 # Initialize dnnlib and TensorFlow.
137 tflib.init_tf(tf_config)
・・
140 # Load training set.
142 training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args)
・・
145 # Construct or load networks.
146 with tf.device('/gpu:0'):
147 if resume_pkl is None or resume_with_new_nets: # ← 新規学習の場合
148 print('Constructing networks...')
149 G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
150 D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
151 Gs = G.clone('Gs')
152 if resume_pkl is not None: # ← 追加学習の場合
・・
159 G.print_layers(); D.print_layers()
・・
実態のネットワークモデルは、 training.networks_stylegan2
モジュールの定義に従って生成されますが、ここでは詳細な説明は割愛(定義のままであり冒頭に紹介しましたStyleGAN2を徹底解剖!の記事をみた方が理解できます)し、実際に生成されたネットワークモデルの印刷のみ紹介します。
G Params OutputShape WeightShape
--- --- --- ---
latents_in - (?, 512) -
labels_in - (?, 0) -
lod - () -
dlatent_avg - (512,) -
G_mapping/latents_in - (?, 512) -
G_mapping/labels_in - (?, 0) -
G_mapping/Normalize - (?, 512) -
G_mapping/Dense0 262656 (?, 512) (512, 512)
G_mapping/Dense1 262656 (?, 512) (512, 512)
G_mapping/Dense2 262656 (?, 512) (512, 512)
G_mapping/Dense3 262656 (?, 512) (512, 512)
G_mapping/Dense4 262656 (?, 512) (512, 512)
G_mapping/Dense5 262656 (?, 512) (512, 512)
G_mapping/Dense6 262656 (?, 512) (512, 512)
G_mapping/Dense7 262656 (?, 512) (512, 512)
G_mapping/Broadcast - (?, 12, 512) -
G_mapping/dlatents_out - (?, 12, 512) -
Truncation/Lerp - (?, 12, 512) -
G_synthesis/dlatents_in - (?, 12, 512) -
G_synthesis/4x4/Const 8192 (?, 512, 4, 4) (1, 512, 4, 4)
G_synthesis/4x4/Conv 2622465 (?, 512, 4, 4) (3, 3, 512, 512)
G_synthesis/4x4/ToRGB 264195 (?, 3, 4, 4) (1, 1, 512, 3)
G_synthesis/8x8/Conv0_up 2622465 (?, 512, 8, 8) (3, 3, 512, 512)
G_synthesis/8x8/Conv1 2622465 (?, 512, 8, 8) (3, 3, 512, 512)
G_synthesis/8x8/Upsample - (?, 3, 8, 8) -
G_synthesis/8x8/ToRGB 264195 (?, 3, 8, 8) (1, 1, 512, 3)
G_synthesis/16x16/Conv0_up 2622465 (?, 512, 16, 16) (3, 3, 512, 512)
G_synthesis/16x16/Conv1 2622465 (?, 512, 16, 16) (3, 3, 512, 512)
G_synthesis/16x16/Upsample - (?, 3, 16, 16) -
G_synthesis/16x16/ToRGB 264195 (?, 3, 16, 16) (1, 1, 512, 3)
G_synthesis/32x32/Conv0_up 2622465 (?, 512, 32, 32) (3, 3, 512, 512)
G_synthesis/32x32/Conv1 2622465 (?, 512, 32, 32) (3, 3, 512, 512)
G_synthesis/32x32/Upsample - (?, 3, 32, 32) -
G_synthesis/32x32/ToRGB 264195 (?, 3, 32, 32) (1, 1, 512, 3)
G_synthesis/64x64/Conv0_up 2622465 (?, 512, 64, 64) (3, 3, 512, 512)
G_synthesis/64x64/Conv1 2622465 (?, 512, 64, 64) (3, 3, 512, 512)
G_synthesis/64x64/Upsample - (?, 3, 64, 64) -
G_synthesis/64x64/ToRGB 264195 (?, 3, 64, 64) (1, 1, 512, 3)
G_synthesis/128x128/Conv0_up 1442561 (?, 256, 128, 128) (3, 3, 512, 256)
G_synthesis/128x128/Conv1 721409 (?, 256, 128, 128) (3, 3, 256, 256)
G_synthesis/128x128/Upsample - (?, 3, 128, 128) -
G_synthesis/128x128/ToRGB 132099 (?, 3, 128, 128) (1, 1, 256, 3)
G_synthesis/images_out - (?, 3, 128, 128) -
G_synthesis/noise0 - (1, 1, 4, 4) -
G_synthesis/noise1 - (1, 1, 8, 8) -
G_synthesis/noise2 - (1, 1, 8, 8) -
G_synthesis/noise3 - (1, 1, 16, 16) -
G_synthesis/noise4 - (1, 1, 16, 16) -
G_synthesis/noise5 - (1, 1, 32, 32) -
G_synthesis/noise6 - (1, 1, 32, 32) -
G_synthesis/noise7 - (1, 1, 64, 64) -
G_synthesis/noise8 - (1, 1, 64, 64) -
G_synthesis/noise9 - (1, 1, 128, 128) -
G_synthesis/noise10 - (1, 1, 128, 128) -
images_out - (?, 3, 128, 128) -
--- --- --- ---
Total 29328669
D Params OutputShape WeightShape
--- --- --- ---
images_in - (?, 3, 128, 128) -
labels_in - (?, 0) -
128x128/FromRGB 1024 (?, 256, 128, 128) (1, 1, 3, 256)
128x128/Conv0 590080 (?, 256, 128, 128) (3, 3, 256, 256)
128x128/Conv1_down 1180160 (?, 512, 64, 64) (3, 3, 256, 512)
128x128/Skip 131072 (?, 512, 64, 64) (1, 1, 256, 512)
64x64/Conv0 2359808 (?, 512, 64, 64) (3, 3, 512, 512)
64x64/Conv1_down 2359808 (?, 512, 32, 32) (3, 3, 512, 512)
64x64/Skip 262144 (?, 512, 32, 32) (1, 1, 512, 512)
32x32/Conv0 2359808 (?, 512, 32, 32) (3, 3, 512, 512)
32x32/Conv1_down 2359808 (?, 512, 16, 16) (3, 3, 512, 512)
32x32/Skip 262144 (?, 512, 16, 16) (1, 1, 512, 512)
16x16/Conv0 2359808 (?, 512, 16, 16) (3, 3, 512, 512)
16x16/Conv1_down 2359808 (?, 512, 8, 8) (3, 3, 512, 512)
16x16/Skip 262144 (?, 512, 8, 8) (1, 1, 512, 512)
8x8/Conv0 2359808 (?, 512, 8, 8) (3, 3, 512, 512)
8x8/Conv1_down 2359808 (?, 512, 4, 4) (3, 3, 512, 512)
8x8/Skip 262144 (?, 512, 4, 4) (1, 1, 512, 512)
4x4/MinibatchStddev - (?, 513, 4, 4) -
4x4/Conv 2364416 (?, 512, 4, 4) (3, 3, 513, 512)
4x4/Dense0 4194816 (?, 512) (8192, 512)
Output 513 (?, 1) (512, 1)
scores_out - (?, 1) -
--- --- --- ---
Total 28389121
本来は1024x1024の層までありますが、128x128の層までの学習モデルの例です
【Gsについて補足】
Gsは、本文中にも記載しましたが、Gの学習中の各iterationにおける値の指数移動平均です。
- $ Gs \leftarrow \beta Gs + ( 1 - \beta) G $
- $ \beta = 0.9978$ (学習設定により異なります。デフォルトでは、バッチサイズ32、1万サンプルで半減する設定と思われます)
- Gモデル中のMappingモデル、Syntheisモデルのパラメータ両方を平均化します
- GsはBackpropagationには関与しません
- GsはGより安定した重みとして、学習後の画像生成に使用します
- GsはマルチGPU設定に影響されません(Gの更新結果が全GPUに同期された後にGsの更新が動作します)
- 教師あり学習で使われるStochastic Weight Averaginと等価の実装と思われます
149行目や150行目のネットワーク生成で使用される Network
モジュールは、Tensorflowでのネットワークモデルの抽象化クラスです。実際に生成されるクラスは、run_training
モジュールの run()
メソッドで指定した G = networks_stylegan2.G_main()
、 D = networks_stylegan2.D_stylegan2()
になります。 G_main()
では、190行目で G_mapping()
を呼び出してMappingモデルを、186行目でG_synthesis_stylegan2()
を呼び出してSynthesisモデルをそれぞれ生成しています。
151 def G_main(
・・
184 # Setup components.
185 if 'synthesis' not in components:
186 components.synthesis = tflib.Network('G_synthesis', func_name=globals()[synthesis_func], **kwargs)
187 num_layers = components.synthesis.input_shape[1]
188 dlatent_size = components.synthesis.input_shape[2]
189 if 'mapping' not in components:
190 components.mapping = tflib.Network('G_mapping', func_name=globals()[mapping_func], dlatent_broadcast=num_layers, **kwargs)
・・
251 def G_mapping(
・・
417 def G_synthesis_stylegan2(
・・
634 def D_stylegan2(
ネットワークモデルの生成が終わると、input情報の設定、オプティマイザの設定(176行目以降)を行います。
165 # Setup training inputs.
166 print('Building TensorFlow graph...')
167 with tf.name_scope('Inputs'), tf.device('/cpu:0'):
168 lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
169 lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
170 minibatch_size_in = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[])
171 minibatch_gpu_in = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[])
172 minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)
173 Gs_beta = 0.5 ** tf.div(tf.cast(minibatch_size_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0
174
175 # Setup optimizers.
176 G_opt_args = dict(G_opt_args)
177 D_opt_args = dict(D_opt_args)
178 for args, reg_interval in [(G_opt_args, G_reg_interval), (D_opt_args, D_reg_interval)]:
179 args['minibatch_multiplier'] = minibatch_multiplier
180 args['learning_rate'] = lrate_in
181 if lazy_regularization:
182 mb_ratio = reg_interval / (reg_interval + 1)
183 args['learning_rate'] *= mb_ratio
184 if 'beta1' in args: args['beta1'] **= mb_ratio
185 if 'beta2' in args: args['beta2'] **= mb_ratio
186 G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
187 D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
188 G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
189 D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)
学習時のオプティマイザは、tflib.Optimizer
クラスで指定(28行目)した Adam
ががデフォルトで使われます。
class Optimizer:
・・
26 def __init__(self,
27 name: str = "Train", # Name string that will appear in TensorFlow graph.
28 tf_optimizer: str = "tf.train.AdamOptimizer", # Underlying optimizer class.
29 learning_rate: TfExpressionEx = 0.001, # Learning rate. Can vary over time.
30 minibatch_multiplier: TfExpressionEx = None, # Treat N consecutive minibatches as one by accumulating gradients.
・・
続いて各GPUごとにトレーニンググラフを作成します。
ここでは、学習データの取得と、損失関数(loss function)を登録します。
220行目の G_loss, G_reg = dnnlib.util.call_func_by_name()
では、 traning_run
モジュールの run()
で指定された training.loss.G_logistic_ns_pathreg()
メソッドが、 222行目のD_loss, D_reg = dnnlib.util.call_func_by_name()
では同様に training.loss.D_logistic_r1()
メソッドが呼び出されます。
191 # Build training graph for each GPU.
192 data_fetch_ops = []
193 for gpu in range(num_gpus):
194 with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
196 # Create GPU-specific shadow copies of G and D.
197 G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
198 D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
200 # Fetch training data via temporary variables.
201 with tf.name_scope('DataFetch'):
202 sched = training_schedule(cur_nimg=int(resume_kimg*1000), training_set=training_set, **sched_args)
203 reals_var = tf.Variable(name='reals', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape))
204 labels_var = tf.Variable(name='labels', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu, training_set.label_size]))
205 reals_write, labels_write = training_set.get_minibatch_tf()
206 reals_write, labels_write = process_reals(reals_write, labels_write, lod_in, mirror_augment, training_set.dynamic_range, drange_net)
207 reals_write = tf.concat([reals_write, reals_var[minibatch_gpu_in:]], axis=0)
208 labels_write = tf.concat([labels_write, labels_var[minibatch_gpu_in:]], axis=0)
209 data_fetch_ops += [tf.assign(reals_var, reals_write)]
210 data_fetch_ops += [tf.assign(labels_var, labels_write)]
211 reals_read = reals_var[:minibatch_gpu_in]
212 labels_read = labels_var[:minibatch_gpu_in]
214 # Evaluate loss functions.
215 lod_assign_ops = []
216 if 'lod' in G_gpu.vars: lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
217 if 'lod' in D_gpu.vars: lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
218 with tf.control_dependencies(lod_assign_ops):
219 with tf.name_scope('G_loss'):
220 G_loss, G_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, **G_loss_args)
221 with tf.name_scope('D_loss'):
222 D_loss, D_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, reals=reals_read, labels=labels_read, **D_loss_args)
224 # Register gradients.
225 if not lazy_regularization:
226 if G_reg is not None: G_loss += G_reg
227 if D_reg is not None: D_loss += D_reg
228 else:
229 if G_reg is not None: G_reg_opt.register_gradients(tf.reduce_mean(G_reg * G_reg_interval), G_gpu.trainables)
230 if D_reg is not None: D_reg_opt.register_gradients(tf.reduce_mean(D_reg * D_reg_interval), D_gpu.trainables)
231 G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
232 D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
続いてオペレーションの設定が行われます。
Generator、Discriminator、reg学習、regDiscriminatorの学習と損失計算のオペレーション設定が行われす。また、Gsのオペレーション設定が行われます。
234 # Setup training ops.
235 data_fetch_op = tf.group(*data_fetch_ops)
236 G_train_op = G_opt.apply_updates()
237 D_train_op = D_opt.apply_updates()
238 G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
239 D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
240 Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
【補足】regについて
論文中に記載されている「Lazy regularization」の実装になります。
正則化項の更新は計算コストとメモリ使用量を削減から更新頻度を下げる(16ミニバッチごとに1回の実行でも影響がないことが分かっているため)戦略をとっています。
学習の実行
起動パラーメタで指定した学習枚数(total_kimg✖️1000枚)まで、指定したバッチサイズ(minibatchRepeat デフォルト=4(1GPUの時))で学習を繰り返します。この時、先ほど生成したオペレーションを実行します(tflib.run(G_train_op, feed_dict)
等)。
266 while cur_nimg < total_kimg * 1000:
267 if dnnlib.RunContext.get().should_stop(): break
269 # Choose training parameters and configure training ops.
270 sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, **sched_args)
271 assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
272 training_set.configure(sched.minibatch_gpu, sched.lod)
273 if reset_opt_for_new_lod:
274 if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
275 G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state()
276 prev_lod = sched.lod
・・
278 # Run training ops.
279 feed_dict = {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_size_in: sched.minibatch_size, minibatch_gpu_in: sched.minibatch_gpu}
280 for _repeat in range(minibatch_repeats):
281 rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus)
282 run_G_reg = (lazy_regularization and running_mb_counter % G_reg_interval == 0)
283 run_D_reg = (lazy_regularization and running_mb_counter % D_reg_interval == 0)
284 cur_nimg += sched.minibatch_size
285 running_mb_counter += 1
287 # Fast path without gradient accumulation.
288 if len(rounds) == 1:
289 tflib.run([G_train_op, data_fetch_op], feed_dict)
290 if run_G_reg:
291 tflib.run(G_reg_op, feed_dict)
292 tflib.run([D_train_op, Gs_update_op], feed_dict)
293 if run_D_reg:
294 tflib.run(D_reg_op, feed_dict)
296 # Slow path with gradient accumulation.
297 else:
298 for _round in rounds:
299 tflib.run(G_train_op, feed_dict)
300 if run_G_reg:
301 for _round in rounds:
302 tflib.run(G_reg_op, feed_dict)
303 tflib.run(Gs_update_op, feed_dict)
304 for _round in rounds:
305 tflib.run(data_fetch_op, feed_dict)
306 tflib.run(D_train_op, feed_dict)
307 if run_D_reg:
308 for _round in rounds:
309 tflib.run(D_reg_op, feed_dict)
ミニバッチの終了都度、学習の終了判定(312行目)、tick判定(313行目)を行います。tickに該当すれば中間レポート(autosummary)の生成(322〜332行目)、現在学習されたモデルでFake画像を生成・保存(336,337行目)、モデルのスナップショット(339,340行目)の作成、生成された画像の品質評価等(341行目、デフォルトではfid50k(Fréchet Inception Distance)で評価)を行います。
311 # Perform maintenance tasks once per tick.
312 done = (cur_nimg >= total_kimg * 1000)
313 if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
314 cur_tick += 1
315 tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
316 tick_start_nimg = cur_nimg
317 tick_time = dnnlib.RunContext.get().get_time_since_last_update()
318 total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time
320 # Report progress.
321 print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f' % (
322 autosummary('Progress/tick', cur_tick),
323 autosummary('Progress/kimg', cur_nimg / 1000.0),
324 autosummary('Progress/lod', sched.lod),
325 autosummary('Progress/minibatch', sched.minibatch_size),
326 dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
327 autosummary('Timing/sec_per_tick', tick_time),
328 autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
329 autosummary('Timing/maintenance_sec', maintenance_time),
330 autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
331 autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
332 autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))
334 # Save snapshots.
335 if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done):
336 grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu)
337 misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
338 if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done):
339 pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000))
340 misc.save_pkl((G, D, Gs), pkl)
341 metrics.run(pkl, run_dir=dnnlib.make_run_dir_path(), data_dir=dnnlib.convert_path(data_dir), num_gpus=num_gpus, tf_config=tf_config)
343 # Update summaries and RunContext.
344 metrics.update_autosummaries()
345 tflib.autosummary.save_summaries(summary_log, cur_nimg)
346 dnnlib.RunContext.get().update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
347 maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time
学習の終了
学習が終了するとモデルを 'network-final.pkl
として保存して終了します(350行目)。
349 # Save final snapshot.
350 misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('network-final.pkl'))
352 # All done.
353 summary_log.close()
354 training_set.close()
以上が、学習時のコードの流れの概要になります。
おわりに
細かい実装までは記載していませんが、おおよそのコードの実装はわかるのかと思います。
この学習済みモデルを用いて画像生成を行いますが、画像生成時のコードの説明は別の記事で行います。