はじめに
この記事は、jaxとそれを活用したライブラリであるhaikuとjmpを使ってAutomatic Mixed Precision(AMP)学習を実装するという内容になります。
Automatic mixed precision(AMP)学習
Automatic mixed precision(AMP)学習とは、演算精度に通常の単精度(float32)だけではなく半精度(float16)の両方を使うことで、GPUメモリ使用量の削減や学習の高速化が可能なテクニックになります。ケースによりますがGPUメモリ使用量が半分くらいになり、計算速度も数倍にもなるため、実用上重要なテクニックです。
以下のNVIDIA公式記事やQiita記事にも詳しく説明されているので、こちらも参考ください。
実際にAMP学習をするためには
- 一部の演算をfloat16で行う(float16行列の積和演算において、加算はfloat32で行う)
- 勾配がアンダーフローして0になることを避けるために、lossのスケーリングを行う(lossを数倍してから逆伝播)
これらを実装する必要があります。
jax
jaxとは一言で言うと、「Just in timeコンパイル(XLA)による高速化が可能なGPU/TPUに特化した自動微分機能付きNumpy」です。最近ではVision Transoformerの公式実装やAlphaFoldの実装にも用いられていたりと広まりを見せています。
haikuとjmp
jaxはあくまでNumpyライクなAPIをだけを提供しているため、単体でDeep learningを実装していくのはつらいものがあります。そこでjaxをベースにしたDeep learning用のライブラリとしてDeepMindがhaikuを提供しています。
そしてAMP学習実装のためのライブラリも同じDeepMindから提供されています。
haiku, jmpともに同じDeepMindから提供されているだけあり、これらのインテグレーションは非常に簡単になっています。そのため今回はAMP学習実装のためにこれらのライブラリを選定しました。
以下では、これらのライブラリを使ってjaxでAMP学習の実装例を書いていきます。
検証環境
本記事では以下の環境で検証しています(2021/12時点)。
- ハード
- OS: Ubuntu18.04
- GPU: Nvidia RTX2080Ti
- CUDA: 11.3.1
- cudnn: 8.2.1.32-1
- ソフト
- python: 3.8
- jax == 0.2.25
- jaxlib == 0.1.73
- dm-haiku == 0.0.5
- jmp == 0.0.2
モデル定義
haikuでモデルを定義するにはhk.Module
を継承し、__init__
内で各レイヤーを定義、__call__
内で順伝播を書く、というのが一般的な方法になります。
Pytochのtorch.nn.Module
やtensorflowのtf.keras.Model
と同様ですね。
import haiku as hk
import jax.numpy as jnp
class CNN(hk.Module):
def __init__(self, num_classes: int) -> None:
super().__init__()
self.conv1 = hk.Conv2D(output_channels=32, kernel_shape=(3, 3))
self.conv2 = hk.Conv2D(output_channels=64, kernel_shape=(3, 3))
self.dense1 = hk.Linear(output_size=256)
self.dense2 = hk.Linear(output_size=num_classes)
self.pool = hk.AvgPool(window_shape=(2, 2), strides=(2, 2), padding="VALID")
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
x = self.conv1(x)
x = jax.nn.relu(x)
x = self.pool(x)
x = self.conv2(x)
x = jax.nn.relu(x)
x = self.pool(x)
x = x.reshape((x.shape[0], -1)) # flatten
x = self.dense1(x)
x = jax.nn.relu(x)
x = self.dense2(x)
return x
演算精度の設定
jmp.Policy
クラスとhk.mixed_precision.set_policy
を使うことで簡単にモデルの演算精度を設定することができます。
AMP学習の場合、パラメータ自体はfloat32で保存し、演算の際にfloat16にキャストする設定にします。
import jmp
# モデルの演算精度のポリシーを設定
policy = jmp.Policy(
param_dtype=jnp.float32,
compute_dtype=jnp.float16 if FLAGS.amp else jnp.float32,
output_dtype=jnp.float32,
)
hk.mixed_precision.set_policy(CNN, policy)
hk.transform
ここからpytorchやtensorflowと少し違うのはhk.transform
を使ってモデルを変換する処理が必要になるところです。
def _forward(x: jnp.ndarray) -> jnp.ndarray:
net = CNN(num_classes=10)
return net(x)
forward = hk.without_apply_rng(hk.transform(_forward))
上記のように関数内でModuleをインスタンス化し、__call__
メソッドに入力を渡して、出力をreturnするようにします。
そして、hk.transform
に関数を渡すことでhk.Transformed
のインスタンスとして受け取ります。
このインスタンスはinit
とapply
メソッドを備えており、それぞれモデルパラメータの初期化、順伝播の実行用に用います。
また、普通ならば順伝播時にランダムキーを渡す必要があるところをhk.without_apply_rng
でラップすることで、引数からランダムキーを削除することも可能です。順伝播時にランダム性が必要(Variational Auto Encoder, dropoutなど)な場合はランダムキーが必要ですが、そうでない場合は冗長なコードは嫌だという方はこれを使うことで少しだけスッキリします。
# hk.without_apply_rngがない場合
forward = hk.transorm(_forward)
logits = forward.apply(params, None, x) # (prams, rng, inputs)
# ある場合
forward = hk.without_apply_rng(hk.transform(_forward))
logits = forward.apply(params, x) # ランダムキーの入力が不要になる
パラメータ初期化
次にモデルが定義できたので、モデルのパラメータを初期化します。
ランダムキーを生成し、init
メソッドに入力のダミー(shapeとdtypeが合っていいれば値は何でもOK)と入力すると、パラメータが返ってきます。モデルとパラメータが完全に分離している辺りが他のフレームワークと異なるところですね。
rng = jax.random.PRNGKey(0)
params = forward.init(rng, x)
また、それと同時にオプティマイザーのパラメータも初期化しておく必要があるのでそれを下記のようにまとめました。
from typing import NamedTuple
import jmp
import optax
from absl import flags, app
flags.DEFINE_bool("amp", False, help="Automatic mixed precision.")
FLAGS = flags.FLAGS
class TrainState(NamedTuple):
"""パラメータ管理用ヘルパークラス"""
params: hk.Params
opt_state: optax.OptState
loss_scale: jmp.LossScale
def make_optimizer() -> optax.GradientTransformation:
return optax.adam(1e-3)
def initial_state(
rng: jnp.ndarray, images: np.ndarray, forward: hk.Transformed
) -> TrainState:
params = forward.init(rng, images)
opt_state = make_optimizer().init(params)
loss_scale = jmp.DynamicLossScale(2.0 ** 15) if FLAGS.amp else jmp.NoOpLossScale()
return TrainState(params, opt_state, loss_scale)
パラメータとして管理したいものをTrainState
というNamedTupleでまとめています。
またloss scalingをjmp.DynamicLossScale
が担っており、AMP学習の時はこちらを使います。
一方で、AMP学習でない時はjmp.NoOpLossScale
クラスというjmp.DynamicLossScale
と同じメソッドを持つが、何も処理をしないというクラスを使います。こうすることで後のコードを変更することなくAMP学習のOn/Offの切り替えができます。
学習ステップ
各種初期化が終わったので、次に学習の1ステップ(loss計算 → 勾配計算 → パラメータ更新)の処理を書いていきます。
Loss計算
まずloss関数を定義します。
def compute_loss(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
one_hot_labels = jax.nn.one_hot(labels, 10)
loss = optax.softmax_cross_entropy(logits, one_hot_labels).mean()
return loss
def train_step(
train_state: TrainState,
images: jnp.ndarray,
labels: jnp.ndarray,
forward: hk.Transformed,
) -> Tuple[TrainState, Dict[str, jnp.ndarray]]:
params, opt_state, loss_scale = train_state
# lossの定義
def loss_fn(params: hk.Params) -> Tuple[jnp.ndarray, jnp.ndarray]:
logits = forward.apply(params, images)
loss = compute_loss(logits, labels)
# Loss scaling
loss = loss_scale.scale(loss)
return loss, logits
forward.apply(params, images)
で順伝播を行い、今回はsoftmax cross entropyをlossとして計算しています。
AMP学習の場合はアンダーフローを避けるためにlossのスケールを調整する必要がありloss_scale.scale(loss)
によってloss_scale.loss_scale
倍されています。
勾配の算出
続いて算出されたlossから勾配を求めます。
# 勾配の算出
(loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
# Unscaling
grads = loss_scale.unscale(grads)
jax.value_and_grad
を使うことでloss_fn
から返り値と勾配を同時に求めることができます。loss_fn
がloss以外の返り値を持つ場合はhas_aux=True
としておきます。
この勾配をloss_scale.unscale
してスケールを戻した後にパラメータ更新など行っていきます。
パラメータ更新
最後にパラメータ更新を行います。
ここでAMP学習の場合は勾配の要素が全て有限(NaN/infでない)場合にのみ更新を行うようにしておきます。
# パラメータ更新
updates, new_opt_state = make_optimizer().update(grads, opt_state, params)
new_params = optax.apply_updates(params=params, updates=updates)
# DynamicLossScaleの場合、勾配のすべての要素が有限(NaN/infでない)の場合のみパラメータの更新を行う
skip_nonfinite_updates = isinstance(loss_scale, jmp.DynamicLossScale)
if skip_nonfinite_updates:
grads_finite = jmp.all_finite(grads)
loss_scale = loss_scale.adjust(grads_finite)
new_params, new_opt_state = jmp.select_tree(
grads_finite,
(new_params, new_opt_state),
(params, opt_state),
)
new_train_state = TrainState(new_params, new_opt_state, loss_scale)
以上で学習の1ステップの処理は完了です。
評価ステップ
更新されたパラメータを用いてテストデータの評価を行います。
def eval_step(
params: hk.Params, images: np.ndarray, labels: np.ndarray, forward: hk.Transformed
) -> Dict[str, jnp.ndarray]:
logits = forward.apply(params, images)
loss = compute_loss(logits, labels)
accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
scalars = {"val_loss": loss, "val_accuracy": accuracy}
return scalars
検証
今回はMNISTを使ってAMPの効果を検証していきます。
jaxはデフォルトではメモリの90%をpreallocateするので、メモリ使用率を確認できるように設定で変更しておきます。
環境変数にXLA_PYTHON_CLIENT_PREALLOCATE=false
を設定することでpreallocationを禁止しておきます。
データセット作成
jax, haikuともにデータ入力パイプラインとなる便利な関数などは用意されていません。
そのため他のフレームワークで提供されている入力パイプラインを使うことが推奨されています(tf.data.Dataset
やtorch.utils.data.DataLoader
など)。今回はtensorflow-datasetを使った例になります。また、MNISTの画像サイズ(28x28)ではAMPの効果がほとんどないため128x128にリサイズして検証しています(オリジナルサイズでの結果はおまけに後述)。
def preprocess(
batch: Dict[str, tf.Tensor], dtype: tf.dtypes
) -> Tuple[tf.Tensor, tf.Tensor]:
images = batch["image"]
# AMP効果把握のため128x128にリサイズ
images = tf.image.resize(images=images, size=(128, 128))
# AMP Onの場合float16にキャスト
images = tf.cast(images, dtype=dtype) / 255.0
labels = batch["label"]
return images, labels
# データセットの作成
batch_size = FLAGS.batch_size
train_ds, test_ds = tfds.load("mnist", split=["train", "test"])
preprocess_fn = functools.partial(
preprocess, dtype=tf.float16 if FLAGS.amp else tf.float32
)
train_ds = (
train_ds.cache()
.batch(batch_size, drop_remainder=True)
.map(
preprocess_fn,
num_parallel_calls=tf.data.experimental.AUTOTUNE,
)
.prefetch(True)
)
test_ds = (
test_ds.cache()
.batch(batch_size)
.map(
preprocess_fn,
num_parallel_calls=tf.data.experimental.AUTOTUNE,
)
)
学習実行
jax.jit
でラップすることでjitコンパイルされ処理が高速化されます。というより、しないと遅すぎるという方が正しいくらいですので必ずjitコンパイルしましょう!
# パラメータ初期化のためのダミー入力を作成
dummy_images, _ = next(train_ds.as_numpy_iterator())
# パラメータ初期化
rng = jax.random.PRNGKey(0)
train_state = initial_state(rng, dummy_images, forward)
# Train/eval loop.
for i in range(FLAGS.epochs):
batch_scalars, eval_batch_scalars = [], []
epoch_scalars = {}
start = time.perf_counter()
# Training
for batch in train_ds.as_numpy_iterator():
images, labels = batch
# jitコンパイルで高速化
train_state, scalars = jax.jit(train_step, static_argnums=3)(
train_state, images, labels, forward
)
batch_scalars.append(scalars)
training_time = time.perf_counter() - start
結果
AMP学習On/Off時のメモリ使用量と1epoch当たりの学習時間を計測した結果↓
AMP | メモリ使用量 | 学習時間/epoch |
---|---|---|
ON | 2023 MB | 4.67 s |
OFF | 6119 MB | 8.41 s |
ということで今回のケースではAMP学習にすることでメモリ使用量が約1/3、学習時間が約1/2に削減されました!!
わずかなコード修正で凄まじい効果ですね!!
まとめ
haiku + jmpを用いることでjaxでAMP学習の実装とその効果を確認できました。
まだまだPytorchやtensorflowほどハイレベルAPI・ライブラリが充実している訳ではないですが、jax周りのエコシステムは日々発展しているので今後にますます期待ですね!!
(今回実装したコードはこちら)
import functools
import time
from typing import Tuple, NamedTuple, Dict
import haiku as hk
import jax
import jax.numpy as jnp
import jmp
import numpy as np
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
from absl import flags, app
flags.DEFINE_bool("amp", True, help="Automatic mixed precision flag.")
flags.DEFINE_integer("batch_size", 128, help="Batch size.")
flags.DEFINE_integer("epochs", 10, help="Epochs.")
FLAGS = flags.FLAGS
class TrainState(NamedTuple):
params: hk.Params
opt_state: optax.OptState
loss_scale: jmp.LossScale
class CNN(hk.Module):
def __init__(self, num_classes: int) -> None:
super().__init__()
self.conv1 = hk.Conv2D(output_channels=32, kernel_shape=(3, 3))
self.conv2 = hk.Conv2D(output_channels=64, kernel_shape=(3, 3))
self.dense1 = hk.Linear(output_size=256)
self.dense2 = hk.Linear(output_size=num_classes)
self.pool = hk.AvgPool(window_shape=(2, 2), strides=(2, 2), padding="VALID")
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
x = self.conv1(x)
x = jax.nn.relu(x)
x = self.pool(x)
x = self.conv2(x)
x = jax.nn.relu(x)
x = self.pool(x)
x = x.reshape((x.shape[0], -1)) # flatten
x = self.dense1(x)
x = jax.nn.relu(x)
x = self.dense2(x)
return x
def _forward(x: jnp.ndarray) -> jnp.ndarray:
net = CNN(num_classes=10)
return net(x)
def make_optimizer() -> optax.GradientTransformation:
return optax.adam(1e-3)
def compute_loss(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
one_hot_labels = jax.nn.one_hot(labels, 10)
loss = optax.softmax_cross_entropy(logits, one_hot_labels).mean()
return loss
def train_step(
train_state: TrainState,
images: jnp.ndarray,
labels: jnp.ndarray,
forward: hk.Transformed,
) -> Tuple[TrainState, Dict[str, jnp.ndarray]]:
params, opt_state, loss_scale = train_state
# lossの定義
def loss_fn(params: hk.Params) -> Tuple[jnp.ndarray, jnp.ndarray]:
logits = forward.apply(params, images)
loss = compute_loss(logits, labels)
# Loss scaling
loss = loss_scale.scale(loss)
return loss, logits
# 勾配の算出
(loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
# Unscaling
grads = loss_scale.unscale(grads)
# パラメータ更新
updates, new_opt_state = make_optimizer().update(grads, opt_state, params)
new_params = optax.apply_updates(params=params, updates=updates)
# DynamicLossScaleの場合、勾配のすべての要素が有限の場合のみパラメータの更新を行う
skip_nonfinite_updates = isinstance(loss_scale, jmp.DynamicLossScale)
if skip_nonfinite_updates:
grads_finite = jmp.all_finite(grads)
loss_scale = loss_scale.adjust(grads_finite)
new_params, new_opt_state = jmp.select_tree(
grads_finite,
(new_params, new_opt_state),
(params, opt_state),
)
new_train_state = TrainState(new_params, new_opt_state, loss_scale)
accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
scalars = {
"train_loss": loss_scale.unscale(loss),
"train_accuracy": accuracy,
"loss_scale": loss_scale.loss_scale,
}
scalars = jmp.cast_to_full(scalars)
return new_train_state, scalars
def eval_step(
params: hk.Params, images: np.ndarray, labels: np.ndarray, forward: hk.Transformed
) -> Dict[str, jnp.ndarray]:
logits = forward.apply(params, images)
loss = compute_loss(logits, labels)
accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
scalars = {"val_loss": loss, "val_accuracy": accuracy}
return scalars
def initial_state(
rng: jnp.ndarray, images: np.ndarray, forward: hk.Transformed
) -> TrainState:
params = forward.init(rng, images)
opt_state = make_optimizer().init(params)
loss_scale = jmp.DynamicLossScale(2.0 ** 15) if FLAGS.amp else jmp.NoOpLossScale()
return TrainState(params, opt_state, loss_scale)
def preprocess(
batch: Dict[str, tf.Tensor], dtype: tf.dtypes
) -> Tuple[tf.Tensor, tf.Tensor]:
images = batch["image"]
images = tf.image.resize(images=images, size=(128, 128))
images = tf.cast(images, dtype=dtype) / 255.0
labels = batch["label"]
return images, labels
def compute_mean_scalars_in_epoch(
batch_scalars: Dict[str, jnp.ndarray]
) -> Dict[str, float]:
batch_scalars_np = jax.device_get(batch_scalars)
epoch_scalars_np = {
k: np.mean([metrics[k] for metrics in batch_scalars_np])
for k in batch_scalars_np[0]
}
return epoch_scalars_np
def main(argv):
# モデルの演算精度のポリシーを設定
policy = jmp.Policy(
param_dtype=jnp.float32,
compute_dtype=jnp.float16 if FLAGS.amp else jnp.float32,
output_dtype=jnp.float32,
)
hk.mixed_precision.set_policy(CNN, policy)
# モデルの変換
forward = hk.without_apply_rng(hk.transform(_forward))
# データセットの作成
batch_size = FLAGS.batch_size
train_ds, test_ds = tfds.load("mnist", split=["train", "test"])
preprocess_fn = functools.partial(
preprocess, dtype=tf.float16 if FLAGS.amp else tf.float32
)
train_ds = (
train_ds.cache()
.batch(batch_size, drop_remainder=True)
.map(
preprocess_fn,
num_parallel_calls=tf.data.experimental.AUTOTUNE,
)
.prefetch(True)
)
test_ds = (
test_ds.cache()
.batch(batch_size)
.map(
preprocess_fn,
num_parallel_calls=tf.data.experimental.AUTOTUNE,
)
)
# パラメータ初期化のためのダミー入力を作成
dummy_images, _ = next(train_ds.as_numpy_iterator())
# パラメータ初期化
rng = jax.random.PRNGKey(0)
train_state = initial_state(rng, dummy_images, forward)
# Train/eval loop.
for i in range(FLAGS.epochs):
batch_scalars, eval_batch_scalars = [], []
epoch_scalars = {}
start = time.perf_counter()
# Training
for batch in train_ds.as_numpy_iterator():
images, labels = batch
train_state, scalars = jax.jit(train_step, static_argnums=3)(
train_state, images, labels, forward
)
batch_scalars.append(scalars)
training_time = time.perf_counter() - start
epoch_scalars.update(compute_mean_scalars_in_epoch(batch_scalars))
# Evaluation
params = train_state.params
for batch_eval in test_ds.as_numpy_iterator():
images, labels = batch_eval
eval_scalars = jax.jit(eval_step, static_argnums=3)(
params, images, labels, forward
)
eval_batch_scalars.append(eval_scalars)
epoch_scalars.update(compute_mean_scalars_in_epoch(eval_batch_scalars))
msg = f"Epoch {i} - {training_time:.5f}s"
for key, value in epoch_scalars.items():
msg += f" - {key}:{value:.5f}"
print(msg)
if __name__ == "__main__":
app.run(main)
おまけ
MNISTのオリジナル画像サイズ(28x28)で検証した場合
AMP | メモリ使用量 | 学習時間/epoch |
---|---|---|
ON | 1191 MB | 1.01 s |
OFF | 1191 MB | 0.77 s |
ということでAMP学習しない方が良いという結果になりました...
モデルサイズ、入力サイズともに小さい場合はAMPのためのオーバーヘッド分だけ遅くなっているということだと考えられます。
AMP学習も全てのケースで有効に働く訳ではなく、モデルサイズ、入力サイズなどに依存するので、導入する際にはきちんと効果があるか確認して、用法用量を正しく守って使いましょう。