12
19

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

(ソースコードメモ)NVIDIA/apex

Last updated at Posted at 2019-10-20

1. はじめに

NVIDIA/apex(のamp)は、FP16の演算を前提とした(順および逆)伝播計算等を支援する関数群である。これにより、2倍から4倍程度早くなる。そのソースコードを読んだメモである。NVIDIA/apexでは、下記の処理を行っている。

  • Python層での関数ラッピング
  • C++関数のPythonへのマッピング

なお、NVIDIA/apexは、OSS公開後、かなりのコード書き換えが行われつつあり、インターフェースがよりシンプルに変わりつつある。ここで記載するのは、2019年10月現在のものである。また、ampは、PyTorchへの取り込みの議論が8月末より行われている。このため、PyTorch 1.4以降で本体に取り込まれる見込みである。(2020年6月追記:本作業は進行中のようであり1.6を目標としているようである。)

また、apex配下で、プロファイラ(pyprof)等も作られているようではある。しかし、nvprof以上のメリットはないと思われるので言及しない。

2. Python関数のラッピングについて

2.1. apex.amp (FP32/FP16の混合計算)

FP32/FP16の混合演算(TensorCore)で、高速に処理するためのしくみが、apex.ampである。これによって、Float32からFloat16に型変換する。また、この変換処理の一つとして、Float16で損失計算の精度維持を行うために必要な損失スケーリング(Loss Scaling)の設定も行われる。Float16の精度補正のためなので、Float32の演算では、モデルや最適化関数の型変更はされない。なお、apex.ampは、関数のデータ型を置き換えるだけであり、Float16の際に、Tensor Coreで演算するか否かはPyTorch側の設定となる。
apex.ampを使うためには、通常の処理に対して、3行追加する必要がある。

# Added after model and optimizer construction
model, optimizer = amp.initialize(model, optimizer, flags...)
...
# loss.backward() changed to:
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
  • 1項目目(amp.initialize)は、
    • ampのモード等の設定(O0からO3等の設定を_amp_stateに書き込む)を行う。損失係数の初期化(LossScalingクラス)もその延長で行う
    • モデル(model)および最適化関数(optimizer)の各層がどの型(float16やfloat32等)で行うかのカスタマイズを行っている。
  • 2項目目(amp.scale_loss)は、逆伝搬の計算の前後に損失係数処理(ロススケーリング)を追加している。

ここでは、1項目目および2項目目について説明する。
はじめに、動作環境の確認を行う。たとえば、cudnnが使えるかなどを確認する。

    if not torch.backends.cudnn.enabled:
        raise RuntimeError(
            "Amp requires torch.backends.cudnn.enabled = True")

次に、モデルのラッピング(置き換え)を行う。モデルのデータ型変換および、入出力のデータ変更等である。functools.partialや、独自に定義したapplier等によりデータ型変換を行う。

    if properties.cast_model_type:
        if properties.keep_batchnorm_fp32:
            for model in models:
                convert_network(model, properties.cast_model_type)
        else:
            for model in models:
                model.to(properties.cast_model_type)

        input_caster = functools.partial(to_type, properties.cast_model_type)
        if cast_model_outputs is not None:
            output_caster = functools.partial(to_type, cast_model_outputs)
        else:
            output_caster = functools.partial(to_type, torch.float32)

        for model in models:
            # Patch the forward method to cast incoming data to the correct type, and
            # outgoing data to float32, so "the user never needs to call .half()."
            # I like writing things explicitly more than decorators.
            def patch_forward(old_fwd):
                def new_fwd(*args, **kwargs):
                    output = old_fwd(*applier(args, input_caster),
                                     **applier(kwargs, input_caster))
                    return applier(output, output_caster)
                return new_fwd

            model.forward = patch_forward(model.forward)

        # State dict trick to recast any preexisting per-param state tensors
        for optimizer in optimizers:
            optimizer.load_state_dict(optimizer.state_dict())

        # patch model.state_dict() to return float32 params
        for model in models:
            for module in model.modules():
                module._register_state_dict_hook(O2StateDictHook(functools.partial(to_type, torch.float32)))

    elif cast_model_outputs is not None:
        output_caster = functools.partial(to_type, cast_model_outputs)

        for model in models:
            def patch_forward(old_fwd):
                def new_fwd(*args, **kwargs):
                    output = old_fwd(*args, **kwargs)
                    return applier(output, output_caster)
                return new_fwd

            model.forward = patch_forward(model.forward)

さらに、O2の初期化処理では、最適化(optimize)ステップは、キャスティング(データ型変換)されずもとの型で計算されるよう設定される。具体的には以下の箇所である。これにより、master weightの精度がFP32で維持される。

    if properties.patch_torch_functions:
        # handle is unused here. It's accessible later through a global value anyway.
        handle = amp_init(loss_scale=properties.loss_scale, verbose=(_amp_state.verbosity == 2))
        for optimizer in optimizers:
            # Disable Amp casting for the optimizer step, because it should only be
            # applied to FP32 master params anyway.
            def patch_step(old_step):
                def new_step(*args, **kwargs):
                    with disable_casts():
                        output = old_step(*args, **kwargs)
                    return output
                return new_step

            optimizer.step = patch_step(optimizer.step)

また、損失係数の考慮(Loss Scaling)が必要になる。amp_initで、AmpHandlerクラスの中の要素として、LossScaler内で、_loss_scaleにパラメータを設定する。デフォルトは、65536(=2^16)である。

class LossScaler(object):
    warned_no_fused_kernel = False
    warned_unscaling_non_fp32_grad = False
    has_fused_kernel = False

    def __init__(self,
                 loss_scale,
                 init_scale=2.**16,
                 scale_factor=2.,
                 scale_window=2000,
                 min_loss_scale=None,
                 max_loss_scale=2.**24):
        if loss_scale == "dynamic":
            self.dynamic = True
            self._loss_scale = min(max_loss_scale, init_scale)
        else:
            self.dynamic = False
            self._loss_scale = loss_scale
        self._max_loss_scale = max_loss_scale
        self._min_loss_scale = min_loss_scale
        self._scale_seq_len = scale_window
        self._unskipped = 0
        self._has_overflow = False
        self._overflow_buf = torch.cuda.IntTensor([0])
        if multi_tensor_applier.available:
            import amp_C
            LossScaler.has_fused_kernel = multi_tensor_applier.available
            LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale
            LossScaler.multi_tensor_axpby_cuda = amp_C.multi_tensor_axpby
        else:
            if not LossScaler.warned_no_fused_kernel:
                maybe_print(
                    "Warning:  multi_tensor_applier fused unscale kernel is unavailable, "
                    "possibly because apex was installed without --cuda_ext --cpp_ext. "
                    "Using Python fallback.  Original ImportError was: " +
                    repr(multi_tensor_applier.import_err),
                    True)
            LossScaler.has_fused_kernel = False
            LossScaler.warned_no_fused_kernel = True

損失係数の計算は、amp.scale_lossで行う。この関数は、contextlibを構成するwithで呼ばれることから、__enter__, __exit__ 相当の処理およびyield以下の設定処理からなる。特に、loss_scaleを設定している場合、以下が呼び出され損失係数が考慮される。

    yield (loss.float())*loss_scale

損失係数LossScalingの演算は、上記yieldの後に、LossScalerクラスで行っている。このLossScaler関数は、apex.amp.scaler.LossScalerである。損失係数は、update_scale()メソッドにより更新される。
なお、モデル等の逆伝搬(grad値の要素)に、Inf/NAN等があるかをチェックする。あれば、損失係数(LossScaling)の変更を行う。

さて、ampの各オプションにおける最適化機能は以下の通りになっている。

opt_level O0 O1 O2 O3
cast_model_type torch.float32 None torch.float16 torch.float16
patch_torch_functions False True False False
keep_batchnorm_fp32 None None True False
master_weights False None True False
loss_scale 1.0 dynamic dynamic 1.0
ここで、上記のパラメータは以下の通りである。ロススケーリングは、O1/O2でしか動かないことに注意を要する。
  • 最適化パラメータ
    • O0 FP32での演算を示す
    • O1 TensorCoreを用いたFP32/FP16混合の演算を示す。(NVIDIAではお勧めの設定)
    • O2 ほとんどFP16の混合演算 master_weightsがFP32で行われる以外はFP16で行われる
    • O3 FP16での演算を示す。
  • 各パラメータ
    • cast_model_typeは、モデルパラメータをどの型に変換するかをしていするパラメータである。
    • patch_torch_functionsは、メソッド等をTensorCore向けに変換するかを設定するパラメータである。
    • keep_batch_norm_fp32は、バッチ正規化の演算を、どの精度で行うかを指定するパラメータである。
    • master_weightsは、演算時の重みをFP32で行うかを設定するパラメータである。
    • loss_scaleは、損失スケーリングを行うか否かを設定するパラメータである。

3. C++関数のPythonへのマッピングについて

3.1. PyTorch側の実装

PyTorchでは、PYBIND11をベースにC++拡張がしやすい仕組みが入っている。具体的には、setuptools.Extensionの型として、C++/CUDAの関数を定義するヘルパークラスが、CppExtension/CUDAExtensionとして提供されている。

3.1.1. Python側

  • class BuildExtension
    • C++ファイルで用いるTORCH_EXTENSION_NAME変数の置き換え処理スクリプトを、ninja用に生成するクラス。

3.1.2. C++側

  • torch/extension.h
    • このヘッダーファイルで、インターフェース定義を取り込んで、関数定義を行う。

3.2. NVIDIA/apexでの実装

3.2.1. Python側

  • setup.py
    • 各ライブラリのPythonでのエントリーポイントを定義する。ext_module.appendが定義している箇所である。以下のモジュールが定義されており、Pythonからimportで呼び出すことが出来る。
      • apex_C テンソルを1次元化する関数(torch._utils._flatten_dense_tensorsを高速化したもの)
      • amp_C
      • fused_adam_cuda
      • fused_layer_norm_cuda
      • sync_bn

例えば、以下のような定義を行っている。

        ext_modules.append(
            CUDAExtension(name='amp_C',
                          sources=['csrc/amp_C_frontend.cpp',
                                   'csrc/multi_tensor_sgd_kernel.cu',
                                   'csrc/multi_tensor_scale_kernel.cu',
                                   'csrc/multi_tensor_axpby_kernel.cu',
                                   'csrc/multi_tensor_l2norm_kernel.cu',
                                   'csrc/multi_tensor_lamb_stage_1.cu',
                                   'csrc/multi_tensor_lamb_stage_2.cu',
                                   'csrc/multi_tensor_adam.cu',
                                   'csrc/multi_tensor_novograd.cu',
                                   'csrc/multi_tensor_lamb.cu'],
                          extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                              'nvcc':['-lineinfo',
                                                      '-O3',
                                                      # '--resource-usage',
                                                      '--use_fast_math'] + version_dependent_macros}))

3.2.2. C++側

例えば、以下のように定義する。なお、TORCH_EXTENSION_NAMEは、setup.pyの延長で呼び出される_define_torch_extension_nameで設定される。

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
        "Fused overflow check + scale for a list of contiguous tensors");
  m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
        "Fused SGD optimizer for list of contiguous tensors");
  m.def("multi_tensor_axpby", &multi_tensor_axpby_cuda,
        "out = a*x + b*y for a list of contiguous tensors");
  m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
        "Computes L2 norm for a list of contiguous tensors");
  m.def("multi_tensor_lamb_stage1_cuda", &multi_tensor_lamb_stage1_cuda,
        "Computes update part of LAMB optimizer");
  m.def("multi_tensor_lamb_stage2_cuda", &multi_tensor_lamb_stage2_cuda,
        "Completes application of gradient to parameters for LAMB optimizer");
  m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer");
  m.def("multi_tensor_novograd", &multi_tensor_novograd_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer");
  m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda,
        "Computes and apply update for LAMB optimizer");
}

A. 参考資料

A.1. マニュアルや資料

A.1.1. NVIDIA

A.1.2. PyTorch

A.1.3. Python

A.1.4. pybind11

A.1.5. 論文より

A.2. コード例

12
19
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
19

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?