LoginSignup
10
3

More than 1 year has passed since last update.

jax + haiku + jmp ではじめるAutomatic Mixed Precision(AMP)なdeep learning

Last updated at Posted at 2021-12-24

はじめに

この記事は、jaxとそれを活用したライブラリであるhaikujmpを使ってAutomatic Mixed Precision(AMP)学習を実装するという内容になります。

Automatic mixed precision(AMP)学習

Automatic mixed precision(AMP)学習とは、演算精度に通常の単精度(float32)だけではなく半精度(float16)の両方を使うことで、GPUメモリ使用量の削減や学習の高速化が可能なテクニックになります。ケースによりますがGPUメモリ使用量が半分くらいになり、計算速度も数倍にもなるため、実用上重要なテクニックです。

以下のNVIDIA公式記事やQiita記事にも詳しく説明されているので、こちらも参考ください。

実際にAMP学習をするためには

  • 一部の演算をfloat16で行う(float16行列の積和演算において、加算はfloat32で行う)
  • 勾配がアンダーフローして0になることを避けるために、lossのスケーリングを行う(lossを数倍してから逆伝播)

これらを実装する必要があります。

jax

jaxとは一言で言うと、「Just in timeコンパイル(XLA)による高速化が可能なGPU/TPUに特化した自動微分機能付きNumpy」です。最近ではVision Transoformerの公式実装AlphaFoldの実装にも用いられていたりと広まりを見せています。

haikuとjmp

jaxはあくまでNumpyライクなAPIをだけを提供しているため、単体でDeep learningを実装していくのはつらいものがあります。そこでjaxをベースにしたDeep learning用のライブラリとしてDeepMindがhaikuを提供しています。

そしてAMP学習実装のためのライブラリも同じDeepMindから提供されています。

haiku, jmpともに同じDeepMindから提供されているだけあり、これらのインテグレーションは非常に簡単になっています。そのため今回はAMP学習実装のためにこれらのライブラリを選定しました。

以下では、これらのライブラリを使ってjaxでAMP学習の実装例を書いていきます。

検証環境

本記事では以下の環境で検証しています(2021/12時点)。

  • ハード
    • OS: Ubuntu18.04
    • GPU: Nvidia RTX2080Ti
    • CUDA: 11.3.1
    • cudnn: 8.2.1.32-1
  • ソフト
    • python: 3.8
    • jax == 0.2.25
    • jaxlib == 0.1.73
    • dm-haiku == 0.0.5
    • jmp == 0.0.2

モデル定義

haikuでモデルを定義するにはhk.Moduleを継承し、__init__内で各レイヤーを定義、__call__内で順伝播を書く、というのが一般的な方法になります。

Pytochのtorch.nn.Moduleやtensorflowのtf.keras.Modelと同様ですね。

import haiku as hk
import jax.numpy as jnp


class CNN(hk.Module):
    def __init__(self, num_classes: int) -> None:
        super().__init__()
        self.conv1 = hk.Conv2D(output_channels=32, kernel_shape=(3, 3))
        self.conv2 = hk.Conv2D(output_channels=64, kernel_shape=(3, 3))
        self.dense1 = hk.Linear(output_size=256)
        self.dense2 = hk.Linear(output_size=num_classes)
        self.pool = hk.AvgPool(window_shape=(2, 2), strides=(2, 2), padding="VALID")

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = self.conv1(x)
        x = jax.nn.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = jax.nn.relu(x)
        x = self.pool(x)
        x = x.reshape((x.shape[0], -1))  # flatten
        x = self.dense1(x)
        x = jax.nn.relu(x)
        x = self.dense2(x)
        return x

演算精度の設定

jmp.Policyクラスとhk.mixed_precision.set_policyを使うことで簡単にモデルの演算精度を設定することができます。

AMP学習の場合、パラメータ自体はfloat32で保存し、演算の際にfloat16にキャストする設定にします。

import jmp


# モデルの演算精度のポリシーを設定
policy = jmp.Policy(
    param_dtype=jnp.float32,
    compute_dtype=jnp.float16 if FLAGS.amp else jnp.float32,
    output_dtype=jnp.float32,
)
hk.mixed_precision.set_policy(CNN, policy)

hk.transform

ここからpytorchやtensorflowと少し違うのはhk.transformを使ってモデルを変換する処理が必要になるところです。

def _forward(x: jnp.ndarray) -> jnp.ndarray:
    net = CNN(num_classes=10)
    return net(x)

forward = hk.without_apply_rng(hk.transform(_forward))

上記のように関数内でModuleをインスタンス化し、__call__メソッドに入力を渡して、出力をreturnするようにします。

そして、hk.transformに関数を渡すことでhk.Transformedのインスタンスとして受け取ります。

このインスタンスはinitapplyメソッドを備えており、それぞれモデルパラメータの初期化、順伝播の実行用に用います。

また、普通ならば順伝播時にランダムキーを渡す必要があるところをhk.without_apply_rngでラップすることで、引数からランダムキーを削除することも可能です。順伝播時にランダム性が必要(Variational Auto Encoder, dropoutなど)な場合はランダムキーが必要ですが、そうでない場合は冗長なコードは嫌だという方はこれを使うことで少しだけスッキリします。

# hk.without_apply_rngがない場合
forward = hk.transorm(_forward)
logits = forward.apply(params, None, x)  # (prams, rng, inputs)

# ある場合
forward = hk.without_apply_rng(hk.transform(_forward))
logits = forward.apply(params, x)  # ランダムキーの入力が不要になる

パラメータ初期化

次にモデルが定義できたので、モデルのパラメータを初期化します。

ランダムキーを生成し、initメソッドに入力のダミー(shapeとdtypeが合っていいれば値は何でもOK)と入力すると、パラメータが返ってきます。モデルとパラメータが完全に分離している辺りが他のフレームワークと異なるところですね。

rng = jax.random.PRNGKey(0)
params = forward.init(rng, x)

また、それと同時にオプティマイザーのパラメータも初期化しておく必要があるのでそれを下記のようにまとめました。

from typing import NamedTuple

import jmp
import optax
from absl import flags, app

flags.DEFINE_bool("amp", False, help="Automatic mixed precision.")
FLAGS = flags.FLAGS


class TrainState(NamedTuple):
    """パラメータ管理用ヘルパークラス"""
    params: hk.Params
    opt_state: optax.OptState
    loss_scale: jmp.LossScale


def make_optimizer() -> optax.GradientTransformation:
    return optax.adam(1e-3)


def initial_state(
    rng: jnp.ndarray, images: np.ndarray, forward: hk.Transformed
) -> TrainState:
    params = forward.init(rng, images)
    opt_state = make_optimizer().init(params)
    loss_scale = jmp.DynamicLossScale(2.0 ** 15) if FLAGS.amp else jmp.NoOpLossScale()
    return TrainState(params, opt_state, loss_scale)

パラメータとして管理したいものをTrainStateというNamedTupleでまとめています。

またloss scalingをjmp.DynamicLossScaleが担っており、AMP学習の時はこちらを使います。

一方で、AMP学習でない時はjmp.NoOpLossScaleクラスというjmp.DynamicLossScaleと同じメソッドを持つが、何も処理をしないというクラスを使います。こうすることで後のコードを変更することなくAMP学習のOn/Offの切り替えができます。

学習ステップ

各種初期化が終わったので、次に学習の1ステップ(loss計算 → 勾配計算 → パラメータ更新)の処理を書いていきます。

Loss計算

まずloss関数を定義します。

def compute_loss(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
    one_hot_labels = jax.nn.one_hot(labels, 10)
    loss = optax.softmax_cross_entropy(logits, one_hot_labels).mean()
    return loss


def train_step(
    train_state: TrainState,
    images: jnp.ndarray,
    labels: jnp.ndarray,
    forward: hk.Transformed,
) -> Tuple[TrainState, Dict[str, jnp.ndarray]]:
    params, opt_state, loss_scale = train_state

    # lossの定義
    def loss_fn(params: hk.Params) -> Tuple[jnp.ndarray, jnp.ndarray]:
        logits = forward.apply(params, images)
        loss = compute_loss(logits, labels)
        # Loss scaling
        loss = loss_scale.scale(loss)
        return loss, logits

forward.apply(params, images)で順伝播を行い、今回はsoftmax cross entropyをlossとして計算しています。

AMP学習の場合はアンダーフローを避けるためにlossのスケールを調整する必要がありloss_scale.scale(loss)によってloss_scale.loss_scale倍されています。

勾配の算出

続いて算出されたlossから勾配を求めます。

# 勾配の算出
(loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
# Unscaling
grads = loss_scale.unscale(grads)

jax.value_and_gradを使うことでloss_fnから返り値と勾配を同時に求めることができます。loss_fnがloss以外の返り値を持つ場合はhas_aux=Trueとしておきます。

この勾配をloss_scale.unscaleしてスケールを戻した後にパラメータ更新など行っていきます。

パラメータ更新

最後にパラメータ更新を行います。

ここでAMP学習の場合は勾配の要素が全て有限(NaN/infでない)場合にのみ更新を行うようにしておきます。

# パラメータ更新
updates, new_opt_state = make_optimizer().update(grads, opt_state, params)
new_params = optax.apply_updates(params=params, updates=updates)
# DynamicLossScaleの場合、勾配のすべての要素が有限(NaN/infでない)の場合のみパラメータの更新を行う
skip_nonfinite_updates = isinstance(loss_scale, jmp.DynamicLossScale)
if skip_nonfinite_updates:
    grads_finite = jmp.all_finite(grads)
    loss_scale = loss_scale.adjust(grads_finite)
    new_params, new_opt_state = jmp.select_tree(
        grads_finite,
        (new_params, new_opt_state),
        (params, opt_state),
    )
new_train_state = TrainState(new_params, new_opt_state, loss_scale)

以上で学習の1ステップの処理は完了です。

評価ステップ

更新されたパラメータを用いてテストデータの評価を行います。

def eval_step(
    params: hk.Params, images: np.ndarray, labels: np.ndarray, forward: hk.Transformed
) -> Dict[str, jnp.ndarray]:
    logits = forward.apply(params, images)
    loss = compute_loss(logits, labels)
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
    scalars = {"val_loss": loss, "val_accuracy": accuracy}
    return scalars

検証

今回はMNISTを使ってAMPの効果を検証していきます。

jaxはデフォルトではメモリの90%をpreallocateするので、メモリ使用率を確認できるように設定で変更しておきます。
環境変数にXLA_PYTHON_CLIENT_PREALLOCATE=falseを設定することでpreallocationを禁止しておきます。

データセット作成

jax, haikuともにデータ入力パイプラインとなる便利な関数などは用意されていません。

そのため他のフレームワークで提供されている入力パイプラインを使うことが推奨されています(tf.data.Datasettorch.utils.data.DataLoaderなど)。今回はtensorflow-datasetを使った例になります。また、MNISTの画像サイズ(28x28)ではAMPの効果がほとんどないため128x128にリサイズして検証しています(オリジナルサイズでの結果はおまけに後述)。

def preprocess(
    batch: Dict[str, tf.Tensor], dtype: tf.dtypes
) -> Tuple[tf.Tensor, tf.Tensor]:
    images = batch["image"]
    # AMP効果把握のため128x128にリサイズ
    images = tf.image.resize(images=images, size=(128, 128))
    # AMP Onの場合float16にキャスト
    images = tf.cast(images, dtype=dtype) / 255.0
    labels = batch["label"]
    return images, labels

# データセットの作成
batch_size = FLAGS.batch_size
train_ds, test_ds = tfds.load("mnist", split=["train", "test"])
preprocess_fn = functools.partial(
    preprocess, dtype=tf.float16 if FLAGS.amp else tf.float32
)
train_ds = (
    train_ds.cache()
    .batch(batch_size, drop_remainder=True)
    .map(
        preprocess_fn,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )
    .prefetch(True)
)
test_ds = (
    test_ds.cache()
    .batch(batch_size)
    .map(
        preprocess_fn,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )
)

学習実行

jax.jitでラップすることでjitコンパイルされ処理が高速化されます。というより、しないと遅すぎるという方が正しいくらいですので必ずjitコンパイルしましょう!

# パラメータ初期化のためのダミー入力を作成
dummy_images, _ = next(train_ds.as_numpy_iterator())
#  パラメータ初期化
rng = jax.random.PRNGKey(0)
train_state = initial_state(rng, dummy_images, forward)

# Train/eval loop.
for i in range(FLAGS.epochs):
    batch_scalars, eval_batch_scalars = [], []
    epoch_scalars = {}
    start = time.perf_counter()
    # Training
    for batch in train_ds.as_numpy_iterator():
        images, labels = batch
        # jitコンパイルで高速化
        train_state, scalars = jax.jit(train_step, static_argnums=3)(
            train_state, images, labels, forward
        )
        batch_scalars.append(scalars)
    training_time = time.perf_counter() - start

結果

AMP学習On/Off時のメモリ使用量と1epoch当たりの学習時間を計測した結果↓

AMP メモリ使用量 学習時間/epoch
ON 2023 MB 4.67 s
OFF 6119 MB 8.41 s

ということで今回のケースではAMP学習にすることでメモリ使用量が約1/3、学習時間が約1/2に削減されました!!

わずかなコード修正で凄まじい効果ですね!!

まとめ

haiku + jmpを用いることでjaxでAMP学習の実装とその効果を確認できました。

まだまだPytorchやtensorflowほどハイレベルAPI・ライブラリが充実している訳ではないですが、jax周りのエコシステムは日々発展しているので今後にますます期待ですね!!

(今回実装したコードはこちら)
import functools
import time
from typing import Tuple, NamedTuple, Dict

import haiku as hk
import jax
import jax.numpy as jnp
import jmp
import numpy as np
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
from absl import flags, app

flags.DEFINE_bool("amp", True, help="Automatic mixed precision flag.")
flags.DEFINE_integer("batch_size", 128, help="Batch size.")
flags.DEFINE_integer("epochs", 10, help="Epochs.")
FLAGS = flags.FLAGS


class TrainState(NamedTuple):
    params: hk.Params
    opt_state: optax.OptState
    loss_scale: jmp.LossScale


class CNN(hk.Module):
    def __init__(self, num_classes: int) -> None:
        super().__init__()
        self.conv1 = hk.Conv2D(output_channels=32, kernel_shape=(3, 3))
        self.conv2 = hk.Conv2D(output_channels=64, kernel_shape=(3, 3))
        self.dense1 = hk.Linear(output_size=256)
        self.dense2 = hk.Linear(output_size=num_classes)
        self.pool = hk.AvgPool(window_shape=(2, 2), strides=(2, 2), padding="VALID")

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = self.conv1(x)
        x = jax.nn.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = jax.nn.relu(x)
        x = self.pool(x)
        x = x.reshape((x.shape[0], -1))  # flatten
        x = self.dense1(x)
        x = jax.nn.relu(x)
        x = self.dense2(x)
        return x


def _forward(x: jnp.ndarray) -> jnp.ndarray:
    net = CNN(num_classes=10)
    return net(x)


def make_optimizer() -> optax.GradientTransformation:
    return optax.adam(1e-3)


def compute_loss(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
    one_hot_labels = jax.nn.one_hot(labels, 10)
    loss = optax.softmax_cross_entropy(logits, one_hot_labels).mean()
    return loss


def train_step(
    train_state: TrainState,
    images: jnp.ndarray,
    labels: jnp.ndarray,
    forward: hk.Transformed,
) -> Tuple[TrainState, Dict[str, jnp.ndarray]]:
    params, opt_state, loss_scale = train_state

    # lossの定義
    def loss_fn(params: hk.Params) -> Tuple[jnp.ndarray, jnp.ndarray]:
        logits = forward.apply(params, images)
        loss = compute_loss(logits, labels)
        # Loss scaling
        loss = loss_scale.scale(loss)
        return loss, logits

    # 勾配の算出
    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
    # Unscaling
    grads = loss_scale.unscale(grads)

    # パラメータ更新
    updates, new_opt_state = make_optimizer().update(grads, opt_state, params)
    new_params = optax.apply_updates(params=params, updates=updates)
    # DynamicLossScaleの場合、勾配のすべての要素が有限の場合のみパラメータの更新を行う
    skip_nonfinite_updates = isinstance(loss_scale, jmp.DynamicLossScale)
    if skip_nonfinite_updates:
        grads_finite = jmp.all_finite(grads)
        loss_scale = loss_scale.adjust(grads_finite)
        new_params, new_opt_state = jmp.select_tree(
            grads_finite,
            (new_params, new_opt_state),
            (params, opt_state),
        )
    new_train_state = TrainState(new_params, new_opt_state, loss_scale)

    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
    scalars = {
        "train_loss": loss_scale.unscale(loss),
        "train_accuracy": accuracy,
        "loss_scale": loss_scale.loss_scale,
    }
    scalars = jmp.cast_to_full(scalars)
    return new_train_state, scalars


def eval_step(
    params: hk.Params, images: np.ndarray, labels: np.ndarray, forward: hk.Transformed
) -> Dict[str, jnp.ndarray]:
    logits = forward.apply(params, images)
    loss = compute_loss(logits, labels)
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
    scalars = {"val_loss": loss, "val_accuracy": accuracy}
    return scalars


def initial_state(
    rng: jnp.ndarray, images: np.ndarray, forward: hk.Transformed
) -> TrainState:
    params = forward.init(rng, images)
    opt_state = make_optimizer().init(params)
    loss_scale = jmp.DynamicLossScale(2.0 ** 15) if FLAGS.amp else jmp.NoOpLossScale()
    return TrainState(params, opt_state, loss_scale)


def preprocess(
    batch: Dict[str, tf.Tensor], dtype: tf.dtypes
) -> Tuple[tf.Tensor, tf.Tensor]:
    images = batch["image"]
    images = tf.image.resize(images=images, size=(128, 128))
    images = tf.cast(images, dtype=dtype) / 255.0
    labels = batch["label"]
    return images, labels


def compute_mean_scalars_in_epoch(
    batch_scalars: Dict[str, jnp.ndarray]
) -> Dict[str, float]:
    batch_scalars_np = jax.device_get(batch_scalars)
    epoch_scalars_np = {
        k: np.mean([metrics[k] for metrics in batch_scalars_np])
        for k in batch_scalars_np[0]
    }
    return epoch_scalars_np


def main(argv):
    # モデルの演算精度のポリシーを設定
    policy = jmp.Policy(
        param_dtype=jnp.float32,
        compute_dtype=jnp.float16 if FLAGS.amp else jnp.float32,
        output_dtype=jnp.float32,
    )
    hk.mixed_precision.set_policy(CNN, policy)

    # モデルの変換
    forward = hk.without_apply_rng(hk.transform(_forward))

    # データセットの作成
    batch_size = FLAGS.batch_size
    train_ds, test_ds = tfds.load("mnist", split=["train", "test"])
    preprocess_fn = functools.partial(
        preprocess, dtype=tf.float16 if FLAGS.amp else tf.float32
    )
    train_ds = (
        train_ds.cache()
        .batch(batch_size, drop_remainder=True)
        .map(
            preprocess_fn,
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
        )
        .prefetch(True)
    )
    test_ds = (
        test_ds.cache()
        .batch(batch_size)
        .map(
            preprocess_fn,
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
        )
    )

    # パラメータ初期化のためのダミー入力を作成
    dummy_images, _ = next(train_ds.as_numpy_iterator())
    #  パラメータ初期化
    rng = jax.random.PRNGKey(0)
    train_state = initial_state(rng, dummy_images, forward)

    # Train/eval loop.
    for i in range(FLAGS.epochs):
        batch_scalars, eval_batch_scalars = [], []
        epoch_scalars = {}
        start = time.perf_counter()
        # Training
        for batch in train_ds.as_numpy_iterator():
            images, labels = batch
            train_state, scalars = jax.jit(train_step, static_argnums=3)(
                train_state, images, labels, forward
            )
            batch_scalars.append(scalars)
        training_time = time.perf_counter() - start
        epoch_scalars.update(compute_mean_scalars_in_epoch(batch_scalars))

        # Evaluation
        params = train_state.params
        for batch_eval in test_ds.as_numpy_iterator():
            images, labels = batch_eval
            eval_scalars = jax.jit(eval_step, static_argnums=3)(
                params, images, labels, forward
            )
            eval_batch_scalars.append(eval_scalars)
        epoch_scalars.update(compute_mean_scalars_in_epoch(eval_batch_scalars))

        msg = f"Epoch {i} - {training_time:.5f}s"
        for key, value in epoch_scalars.items():
            msg += f" - {key}:{value:.5f}"
        print(msg)


if __name__ == "__main__":
    app.run(main)

おまけ

MNISTのオリジナル画像サイズ(28x28)で検証した場合

AMP メモリ使用量 学習時間/epoch
ON 1191 MB 1.01 s
OFF 1191 MB 0.77 s

ということでAMP学習しない方が良いという結果になりました...

モデルサイズ、入力サイズともに小さい場合はAMPのためのオーバーヘッド分だけ遅くなっているということだと考えられます。

AMP学習も全てのケースで有効に働く訳ではなく、モデルサイズ、入力サイズなどに依存するので、導入する際にはきちんと効果があるか確認して、用法用量を正しく守って使いましょう。

10
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
3