1. はじめに
NVIDIA/apex(のamp)は、FP16の演算を前提とした(順および逆)伝播計算等を支援する関数群である。これにより、2倍から4倍程度早くなる。そのソースコードを読んだメモである。NVIDIA/apexでは、下記の処理を行っている。
- Python層での関数ラッピング
- C++関数のPythonへのマッピング
なお、NVIDIA/apexは、OSS公開後、かなりのコード書き換えが行われつつあり、インターフェースがよりシンプルに変わりつつある。ここで記載するのは、2019年10月現在のものである。また、ampは、PyTorchへの取り込みの議論が8月末より行われている。このため、PyTorch 1.4以降で本体に取り込まれる見込みである。(2020年6月追記:本作業は進行中のようであり1.6を目標としているようである。)
また、apex配下で、プロファイラ(pyprof)等も作られているようではある。しかし、nvprof以上のメリットはないと思われるので言及しない。
2. Python関数のラッピングについて
2.1. apex.amp (FP32/FP16の混合計算)
FP32/FP16の混合演算(TensorCore)で、高速に処理するためのしくみが、apex.ampである。これによって、Float32からFloat16に型変換する。また、この変換処理の一つとして、Float16で損失計算の精度維持を行うために必要な損失スケーリング(Loss Scaling)の設定も行われる。Float16の精度補正のためなので、Float32の演算では、モデルや最適化関数の型変更はされない。なお、apex.ampは、関数のデータ型を置き換えるだけであり、Float16の際に、Tensor Coreで演算するか否かはPyTorch側の設定となる。
apex.ampを使うためには、通常の処理に対して、3行追加する必要がある。
# Added after model and optimizer construction
model, optimizer = amp.initialize(model, optimizer, flags...)
...
# loss.backward() changed to:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
- 1項目目(amp.initialize)は、
- ampのモード等の設定(O0からO3等の設定を
_amp_state
に書き込む)を行う。損失係数の初期化(LossScalingクラス)もその延長で行う - モデル(model)および最適化関数(optimizer)の各層がどの型(float16やfloat32等)で行うかのカスタマイズを行っている。
- ampのモード等の設定(O0からO3等の設定を
- 2項目目(amp.scale_loss)は、逆伝搬の計算の前後に損失係数処理(ロススケーリング)を追加している。
ここでは、1項目目および2項目目について説明する。
はじめに、動作環境の確認を行う。たとえば、cudnnが使えるかなどを確認する。
if not torch.backends.cudnn.enabled:
raise RuntimeError(
"Amp requires torch.backends.cudnn.enabled = True")
次に、モデルのラッピング(置き換え)を行う。モデルのデータ型変換および、入出力のデータ変更等である。functools.partialや、独自に定義したapplier等によりデータ型変換を行う。
if properties.cast_model_type:
if properties.keep_batchnorm_fp32:
for model in models:
convert_network(model, properties.cast_model_type)
else:
for model in models:
model.to(properties.cast_model_type)
input_caster = functools.partial(to_type, properties.cast_model_type)
if cast_model_outputs is not None:
output_caster = functools.partial(to_type, cast_model_outputs)
else:
output_caster = functools.partial(to_type, torch.float32)
for model in models:
# Patch the forward method to cast incoming data to the correct type, and
# outgoing data to float32, so "the user never needs to call .half()."
# I like writing things explicitly more than decorators.
def patch_forward(old_fwd):
def new_fwd(*args, **kwargs):
output = old_fwd(*applier(args, input_caster),
**applier(kwargs, input_caster))
return applier(output, output_caster)
return new_fwd
model.forward = patch_forward(model.forward)
# State dict trick to recast any preexisting per-param state tensors
for optimizer in optimizers:
optimizer.load_state_dict(optimizer.state_dict())
# patch model.state_dict() to return float32 params
for model in models:
for module in model.modules():
module._register_state_dict_hook(O2StateDictHook(functools.partial(to_type, torch.float32)))
elif cast_model_outputs is not None:
output_caster = functools.partial(to_type, cast_model_outputs)
for model in models:
def patch_forward(old_fwd):
def new_fwd(*args, **kwargs):
output = old_fwd(*args, **kwargs)
return applier(output, output_caster)
return new_fwd
model.forward = patch_forward(model.forward)
さらに、O2の初期化処理では、最適化(optimize)ステップは、キャスティング(データ型変換)されずもとの型で計算されるよう設定される。具体的には以下の箇所である。これにより、master weightの精度がFP32で維持される。
if properties.patch_torch_functions:
# handle is unused here. It's accessible later through a global value anyway.
handle = amp_init(loss_scale=properties.loss_scale, verbose=(_amp_state.verbosity == 2))
for optimizer in optimizers:
# Disable Amp casting for the optimizer step, because it should only be
# applied to FP32 master params anyway.
def patch_step(old_step):
def new_step(*args, **kwargs):
with disable_casts():
output = old_step(*args, **kwargs)
return output
return new_step
optimizer.step = patch_step(optimizer.step)
また、損失係数の考慮(Loss Scaling)が必要になる。amp_initで、AmpHandlerクラスの中の要素として、LossScaler内で、_loss_scaleにパラメータを設定する。デフォルトは、65536(=2^16)である。
class LossScaler(object):
warned_no_fused_kernel = False
warned_unscaling_non_fp32_grad = False
has_fused_kernel = False
def __init__(self,
loss_scale,
init_scale=2.**16,
scale_factor=2.,
scale_window=2000,
min_loss_scale=None,
max_loss_scale=2.**24):
if loss_scale == "dynamic":
self.dynamic = True
self._loss_scale = min(max_loss_scale, init_scale)
else:
self.dynamic = False
self._loss_scale = loss_scale
self._max_loss_scale = max_loss_scale
self._min_loss_scale = min_loss_scale
self._scale_seq_len = scale_window
self._unskipped = 0
self._has_overflow = False
self._overflow_buf = torch.cuda.IntTensor([0])
if multi_tensor_applier.available:
import amp_C
LossScaler.has_fused_kernel = multi_tensor_applier.available
LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale
LossScaler.multi_tensor_axpby_cuda = amp_C.multi_tensor_axpby
else:
if not LossScaler.warned_no_fused_kernel:
maybe_print(
"Warning: multi_tensor_applier fused unscale kernel is unavailable, "
"possibly because apex was installed without --cuda_ext --cpp_ext. "
"Using Python fallback. Original ImportError was: " +
repr(multi_tensor_applier.import_err),
True)
LossScaler.has_fused_kernel = False
LossScaler.warned_no_fused_kernel = True
損失係数の計算は、amp.scale_loss
で行う。この関数は、contextlibを構成するwith
で呼ばれることから、__enter__
, __exit__
相当の処理およびyield以下の設定処理からなる。特に、loss_scale
を設定している場合、以下が呼び出され損失係数が考慮される。
yield (loss.float())*loss_scale
損失係数LossScalingの演算は、上記yieldの後に、LossScalerクラスで行っている。このLossScaler関数は、apex.amp.scaler.LossScaler
である。損失係数は、update_scale()メソッドにより更新される。
なお、モデル等の逆伝搬(grad値の要素)に、Inf/NAN等があるかをチェックする。あれば、損失係数(LossScaling)の変更を行う。
さて、ampの各オプションにおける最適化機能は以下の通りになっている。
opt_level | O0 | O1 | O2 | O3 |
---|---|---|---|---|
cast_model_type | torch.float32 | None | torch.float16 | torch.float16 |
patch_torch_functions | False | True | False | False |
keep_batchnorm_fp32 | None | None | True | False |
master_weights | False | None | True | False |
loss_scale | 1.0 | dynamic | dynamic | 1.0 |
ここで、上記のパラメータは以下の通りである。ロススケーリングは、O1/O2でしか動かないことに注意を要する。 |
- 最適化パラメータ
- O0 FP32での演算を示す
- O1 TensorCoreを用いたFP32/FP16混合の演算を示す。(NVIDIAではお勧めの設定)
- O2 ほとんどFP16の混合演算 master_weightsがFP32で行われる以外はFP16で行われる
- O3 FP16での演算を示す。
- 各パラメータ
- cast_model_typeは、モデルパラメータをどの型に変換するかをしていするパラメータである。
- patch_torch_functionsは、メソッド等をTensorCore向けに変換するかを設定するパラメータである。
- keep_batch_norm_fp32は、バッチ正規化の演算を、どの精度で行うかを指定するパラメータである。
- master_weightsは、演算時の重みをFP32で行うかを設定するパラメータである。
- loss_scaleは、損失スケーリングを行うか否かを設定するパラメータである。
3. C++関数のPythonへのマッピングについて
3.1. PyTorch側の実装
PyTorchでは、PYBIND11をベースにC++拡張がしやすい仕組みが入っている。具体的には、setuptools.Extensionの型として、C++/CUDAの関数を定義するヘルパークラスが、CppExtension/CUDAExtensionとして提供されている。
3.1.1. Python側
-
class BuildExtension
- C++ファイルで用いるTORCH_EXTENSION_NAME変数の置き換え処理スクリプトを、ninja用に生成するクラス。
3.1.2. C++側
-
torch/extension.h
- このヘッダーファイルで、インターフェース定義を取り込んで、関数定義を行う。
3.2. NVIDIA/apexでの実装
3.2.1. Python側
-
setup.py
- 各ライブラリのPythonでのエントリーポイントを定義する。ext_module.appendが定義している箇所である。以下のモジュールが定義されており、Pythonからimportで呼び出すことが出来る。
- apex_C テンソルを1次元化する関数(torch._utils._flatten_dense_tensorsを高速化したもの)
- amp_C
- fused_adam_cuda
- fused_layer_norm_cuda
- sync_bn
- 各ライブラリのPythonでのエントリーポイントを定義する。ext_module.appendが定義している箇所である。以下のモジュールが定義されており、Pythonからimportで呼び出すことが出来る。
例えば、以下のような定義を行っている。
ext_modules.append(
CUDAExtension(name='amp_C',
sources=['csrc/amp_C_frontend.cpp',
'csrc/multi_tensor_sgd_kernel.cu',
'csrc/multi_tensor_scale_kernel.cu',
'csrc/multi_tensor_axpby_kernel.cu',
'csrc/multi_tensor_l2norm_kernel.cu',
'csrc/multi_tensor_lamb_stage_1.cu',
'csrc/multi_tensor_lamb_stage_2.cu',
'csrc/multi_tensor_adam.cu',
'csrc/multi_tensor_novograd.cu',
'csrc/multi_tensor_lamb.cu'],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-lineinfo',
'-O3',
# '--resource-usage',
'--use_fast_math'] + version_dependent_macros}))
3.2.2. C++側
- C++側でのモジュール定義
- (apex_C)apex/csrc/flatten_unflatten.cpp
- (amp_C)apex/csrc/amp_C_frontend.cpp
- (fused_adam_cuda)apex/csrc/fused_adam_cuda.cpp
- (fused_layer_norm_cuda)apex/csrc/layer_norm_cuda.cpp
- (sync_bn)apex/csrc/syncbn.cpp
例えば、以下のように定義する。なお、TORCH_EXTENSION_NAMEは、setup.pyの延長で呼び出される_define_torch_extension_name
で設定される。
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
"Fused overflow check + scale for a list of contiguous tensors");
m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
"Fused SGD optimizer for list of contiguous tensors");
m.def("multi_tensor_axpby", &multi_tensor_axpby_cuda,
"out = a*x + b*y for a list of contiguous tensors");
m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
"Computes L2 norm for a list of contiguous tensors");
m.def("multi_tensor_lamb_stage1_cuda", &multi_tensor_lamb_stage1_cuda,
"Computes update part of LAMB optimizer");
m.def("multi_tensor_lamb_stage2_cuda", &multi_tensor_lamb_stage2_cuda,
"Completes application of gradient to parameters for LAMB optimizer");
m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
"Compute and apply gradient update to parameters for Adam optimizer");
m.def("multi_tensor_novograd", &multi_tensor_novograd_cuda,
"Compute and apply gradient update to parameters for Adam optimizer");
m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda,
"Computes and apply update for LAMB optimizer");
}
A. 参考資料
A.1. マニュアルや資料
A.1.1. NVIDIA
- Training With Mixed Precision
- apex (A Pytorch EXtension)
- (blog)NVIDIA Apex: Tools for Easy Mixed-Precision Training in PyTorch
- (cudnn Developer Guide)2.9. Tensor Core Operations
A.1.2. PyTorch
A.1.3. Python
A.1.4. pybind11
A.1.5. 論文より
A.2. コード例
-
NVIDIA/apex
-
apex/amp
- amp.initialize(ampの初期化関数)
- amp.scale_loss(ampの損失係数設定)
-
amp._amp_state(AMPの状態管理用クラス)
-
amp._amp_state.opt_properties
- class Properties(object) (opt_propertiesに、O0からO3に相当する最適化パラメータを設定する。)
-
- パッチなどいくつか
- Unified mixed precision API + backend performance improvements #173 FP16_OptimizerからLossScalerへ変更(2019/02) 昔は、FP16_Optimizer前提の実装であったが、それが換わったことがわかる。
-
apex/amp