機械学習の世界で広く使われている最適化アルゴリズムの一つにADAM(Adaptive Moment Estimation)があります。この記事では、Rustで実装された機械学習フレームワーク「Burn」のソースコードを読み解きながら、ADAMの仕組みと実装方法について学んでいきましょう。
はじめに
ADAMは2014年に発表された最適化アルゴリズムで、その効率的な学習能力から現代のディープラーニングでは標準的に使用されています。特に、学習率の自動調整機能を持ち、パラメータごとに適応的な更新を行える点が大きな特徴です。
この記事を読むことで以下のことが理解できるようになります:
- ADAMアルゴリズムの基本的な仕組み
- Rustにおける実装方法と設計パターン
- 実際のフレームワークでどのように使われているか
前提知識
この記事を理解するために以下の知識があると役立ちます:
- 機械学習の基礎概念(勾配降下法、損失関数など)
- Rustの基本的な文法
- オプティマイザの役割の理解
ADAMアルゴリズムの数学的理解
ADAMの動作原理を数式を使って理解しましょう。機械学習では、損失関数 $L(\theta)$ を最小化するパラメータ $\theta$ を見つけることが目標です。
標準的な勾配降下法
最も基本的な勾配降下法では、パラメータを次のように更新します:
$\theta_{t+1} = \theta_t - \alpha \cdot \nabla L(\theta_t)$
ここで:
- $\theta_t$ は時刻 $t$ でのパラメータ値
- $\alpha$ は学習率
- $\nabla L(\theta_t)$ は損失関数の勾配
この方法では、全パラメータに対して同じ学習率を適用するため、非効率的な場合があります。
ADAMの改良点
ADAMは次の2つの主要な改良を導入しています:
- モーメンタム: 過去の勾配情報を活用して更新方向を安定化
- 適応的学習率: パラメータごとに最適な学習率を自動調整
ADAMの更新則
ADAMは以下のステップでパラメータを更新します:
ステップ1: モーメントの計算
$$
m_t = β₁ · m_{t-1} + (1 - β₁) · g_t
$$
$$
v_t = β₂ · v_{t-1} + (1 - β₂) · g_t²
$$
ここで:
- $g_t$ は現在の勾配 $\nabla L(\theta_t)$
- $m_t$ は勾配の指数移動平均(一次モーメント)
- $v_t$ は勾配の二乗の指数移動平均(二次モーメント)
- $\beta_1$ と $\beta_2$ はハイパーパラメータ(一般的に $\beta_1 = 0.9$, $\beta_2 = 0.999$)
この式は「過去の情報を $\beta$ の割合で保持し、新しい情報を $1-\beta$ の割合で取り入れる」という意味です。
ステップ2: バイアス補正
学習の初期段階では、$m_t$と$v_t$は0で初期化されるため、値が小さく偏ってしまいます。これを補正します:
$$
m̂_t = m_t / (1 - β₁ᵗ)
$$
$$
v̂_t = v_t / (1 - β₂ᵗ)
$$
$t$ が大きくなるにつれて補正効果は小さくなるのが特徴です。
ステップ3: パラメータ更新
$$
θ_{t+1} = θ_t - α · m̂_t / (√v̂_t + ε)
$$
ここで $\epsilon$ は数値安定性のための小さな値(通常 $10^{-8}$ 程度)です。
ADAMオプティマイザの設定構造体
まずはBurnにおけるADAMの設定部分から見ていきましょう:
#[derive(Config)]
pub struct AdamConfig {
/// Parameter for Adam.
#[config(default = 0.9)]
beta_1: f32,
/// Parameter for Adam.
#[config(default = 0.999)]
beta_2: f32,
/// A value required for numerical stability.
#[config(default = 1e-5)]
epsilon: f32,
/// [Weight decay](WeightDecayConfig) config.
weight_decay: Option<WeightDecayConfig>,
/// [Gradient Clipping](GradientClippingConfig) config.
grad_clipping: Option<GradientClippingConfig>,
}
このAdamConfig
構造体からわかることは:
-
beta_1とbeta_2: ADAMアルゴリズムの中心的なハイパーパラメータで、それぞれ一次と二次のモーメント更新の係数です。デフォルト値は論文で推奨されている値になっています。
-
epsilon: 数値の安定性を保つための小さな値です。ゼロ除算を防ぐために使われます。
-
weight_decay: 重み減衰の設定をオプションで指定できます。過学習を防ぐのに役立ちます。
-
grad_clipping: 勾配クリッピングの設定もオプションで指定可能です。勾配爆発を防ぐために使われます。
Adamの構造体と状態
続いて、実際のAdam実装の構造体を見てみましょう:
#[derive(Clone)]
pub struct Adam {
momentum: AdaptiveMomentum,
weight_decay: Option<WeightDecay>,
}
#[derive(Record, Clone, new)]
pub struct AdamState<B: Backend, const D: usize> {
/// The current adaptive momentum.
pub momentum: AdaptiveMomentumState<B, D>,
}
Adamは「状態を持つ」オプティマイザです。AdamState
には現在のモーメンタム状態が保持され、反復ごとに更新されていきます。
ADAM更新ステップの実装
step
メソッドをみましょう:
fn step<const D: usize>(
&self,
lr: LearningRate,
tensor: Tensor<B, D>,
mut grad: Tensor<B, D>,
state: Option<Self::State<D>>,
) -> (Tensor<B, D>, Option<Self::State<D>>) {
let mut state_momentum = None;
if let Some(state) = state {
state_momentum = Some(state.momentum);
}
if let Some(weight_decay) = &self.weight_decay {
grad = weight_decay.transform(grad, tensor.clone());
}
let (grad, state_momentum) = self.momentum.transform(grad, state_momentum);
let state = AdamState::new(state_momentum);
let delta = grad.mul_scalar(lr);
(tensor - delta, Some(state))
}
- まず既存の状態があれば取り出します
- 重み減衰が設定されていれば勾配を変換します
-
momentum.transform
を呼び出してADAMの中核となる適応的モーメント推定を行います - 新しい状態を生成し、学習率を掛けた勾配で重みを更新します
適応的モーメントの実装
重要な適応的モーメント推定のロジックはAdaptiveMomentum
構造体のtransform
メソッドに実装されています:
pub fn transform<B: Backend, const D: usize>(
&self,
grad: Tensor<B, D>,
momentum_state: Option<AdaptiveMomentumState<B, D>>,
) -> (Tensor<B, D>, AdaptiveMomentumState<B, D>) {
let state = if let Some(mut state) = momentum_state {
let factor = 1.0 - self.beta_1;
state.moment_1 = state
.moment_1
.mul_scalar(self.beta_1)
.add(grad.clone().mul_scalar(factor));
let factor = 1.0 - self.beta_2;
state.moment_2 = state
.moment_2
.mul_scalar(self.beta_2)
.add(grad.powf_scalar(2.0).mul_scalar(factor));
state.time += 1;
state
} else {
let factor = 1.0 - self.beta_1;
let moment_1 = grad.clone().mul_scalar(factor);
let factor = 1.0 - self.beta_2;
let moment_2 = grad.powf_scalar(2.0).mul_scalar(factor);
AdaptiveMomentumState::new(1, moment_1, moment_2)
};
let time = (state.time as i32).elem();
let moment_1_corrected = state
.moment_1
.clone()
.div_scalar(1f32 - self.beta_1.powi(time));
let moment_2_corrected = state
.moment_2
.clone()
.div_scalar(1f32 - self.beta_2.powi(time));
let grad = moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon));
(grad, state)
}
ここでADAMアルゴリズムが実装されています:
-
一次モーメント(平均)の更新:
state.moment_1 = state.moment_1.mul_scalar(self.beta_1).add(grad.clone().mul_scalar(factor));
これは勾配の指数移動平均を計算しています。
-
二次モーメント(分散)の更新:
state.moment_2 = state.moment_2.mul_scalar(self.beta_2).add(grad.powf_scalar(2.0).mul_scalar(factor));
これは勾配の二乗の指数移動平均を計算しています。
-
バイアス補正:
let moment_1_corrected = state.moment_1.clone().div_scalar(1f32 - self.beta_1.powi(time)); let moment_2_corrected = state.moment_2.clone().div_scalar(1f32 - self.beta_2.powi(time));
初期ステップでのモーメント推定のバイアスを補正します。
-
適応的学習率の計算:
let grad = moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon));
一次モーメントを二次モーメントの平方根で割ることで、パラメータごとに適応した更新量を得ます。
なぜこの割り算が重要なのか?
1. 勾配の大きさの正規化
パラメータ更新を考える時、単に勾配(またはその移動平均である一次モーメント)をそのまま使うと問題があります。勾配の大きさはパラメータによって大きく異なる可能性があるからです。
例えば:
- あるパラメータの勾配が常に大きい(例:10.0前後)
- 別のパラメータの勾配が常に小さい(例:0.01前後)
この場合、同じ学習率を使うと、勾配が大きいパラメータは過剰に更新され、小さいパラメータは十分に更新されません。
2. 二次モーメントの意味
$v_t$ (二次モーメント) は勾配の二乗の指数移動平均です。これは、「このパラメータの勾配はどれくらい変動しているか」という情報を持っています:
- $v_t$ が大きい → そのパラメータの勾配は大きく、またはよく変動している
- $v_t$ が小さい → そのパラメータの勾配は小さく、比較的安定している
3. 割り算の効果
$\hat{m}_t$ を $\sqrt{\hat{v}_t}$ で割ることで何が起きるか考えましょう:
-
勾配変動が大きいパラメータ($v_t$ が大きい)では、$\hat{m}_t / \sqrt{\hat{v}_t}$ は小さくなります → 更新が慎重になる
-
勾配変動が安定しているパラメータ($v_t$ が小さい)では、$\hat{m}_t / \sqrt{\hat{v}_t}$ は大きくなります → 更新が積極的になる
ADAMアルゴリズム実装を理解する
-
一次モーメント $m_t$(方向):
- 過去の勾配の方向を記憶することで、ノイズの影響を減らし、より安定した方向に進みます
- 局所的な変動に惑わされにくくなります
-
二次モーメント $v_t$(スケール):
- 勾配の二乗値の履歴を追跡し、パラメータごとの更新スケールを調整します
- 頻繁に大きな勾配を持つパラメータ($v_t$ が大きい)は更新量が小さくなり、安定したパラメータは大きく更新されます
- これは「道の状況に応じて速度を調整する」ようなもので、急な斜面では慎重に、緩やかな斜面では大胆に進みます
-
バイアス補正:
- 学習初期では、$m_t$ と $v_t$ は0に初期化されるため、過小評価される傾向があります
- 補正係数 $(1-\beta^t)$ で割ることで、この初期バイアスを解消します
- 特に学習の初期段階で重要な役割を果たします
-
適応的更新量 $\alpha \cdot \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)$:
- 各パラメータに最適な更新量を自動的に決定します
- 二次モーメントが大きいパラメータは小さく更新され、小さいパラメータは大きく更新されます
ADAMの具体例
パラメータ1つの簡単な例で考えてみます。3ステップの勾配が $g_1 = 4.0$, $g_2 = 2.0$, $g_3 = 1.0$ だったとします。
$\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$, $\alpha = 0.001$ として計算すると:
ステップ1:
- $m_1 = 0.4$, $v_1 = 0.016$
- バイアス補正後: $\hat{m}_1 = 4.0$, $\hat{v}_1 = 16.0$
- 更新量: $0.00100$
ステップ2:
- $m_2 = 0.56$, $v_2 = 0.02$
- バイアス補正後: $\hat{m}_2 = 2.95$, $\hat{v}_2 = 10.0$
- 更新量: $0.00093$
ステップ3:
- $m_3 = 0.604$, $v_3 = 0.021$
- バイアス補正後: $\hat{m}_3 = 2.12$, $\hat{v}_3 = 7.0$
- 更新量: $0.00080$
勾配が小さくなるにつれて、ADAMは自動的に更新量を調整します。これが、ADAMが効果的な理由です:
-
パラメータごとの適応的学習率:各パラメータの勾配の履歴に基づいて学習率を調整します。
-
モーメンタムと二次モーメントの組み合わせ:一次モーメントによる慣性と二次モーメントによる適応的なスケーリングを兼ね備えています。
-
バイアス補正:初期ステップでの推定値のバイアスを補正することで、学習の初期段階での性能を向上させます。
使用例
BurnフレームワークでADAMを使う際は、以下のようにシンプルに設定できます:
let optimizer = AdamConfig::new()
.with_beta_1(0.9)
.with_beta_2(0.999)
.with_epsilon(1e-8)
.with_weight_decay(Some(WeightDecayConfig::new(0.01)))
.init();
これで、モデルのトレーニングにADAMオプティマイザを使用する準備が整います。
よくある質問
Q: SGDやモーメンタムと比べて、ADAMはどのような場合に有利ですか?
A: ADAMは以下のような場合に特に効果的です:
- 大規模なデータセットや複雑なモデルの学習
- スパースな勾配が発生するようなケース(自然言語処理など)
- ハイパーパラメータのチューニングに時間をかけられない場合
- 非定常的な目的関数(時間とともに変化する問題)の最適化
一方、単純な問題やデータセットが小さい場合は、SGDやモーメンタムの方が良いパフォーマンスを示すこともあります。
Q: ADAMのハイパーパラメータはどのように設定すればよいですか?
A: ADAMは比較的ハイパーパラメータの設定に robust なのが特徴ですが、一般的には:
- 学習率 $\alpha$: 0.001(ネットワークの構造によって調整が必要)
- $\beta_1$: 0.9(一次モーメントの減衰率)
- $\beta_2$: 0.999(二次モーメントの減衰率)
- $\epsilon$: $10^{-8}$(数値安定性のための定数)
これらは論文で推奨されている値で、多くの場合そのまま使えますが、特に学習率は問題に応じて調整する価値があります。
Q: ADAMの欠点や注意点はありますか?
A: ADAMにもいくつかの課題があります:
- 一部の問題では汎化性能がSGDより劣ることがある
- メモリ使用量が大きい(各パラメータに対して2つの状態を保持)
- 学習の後半でSGDよりも精度が出ない場合がある
- 一部の問題では収束が保証されない場合がある
こうした課題に対処するために、AMSGrad, AdamWなどの派生アルゴリズムも提案されています。
まとめ
この記事では、RustのBurnフレームワークのソースコードを通して、ADAM最適化アルゴリズムの実装と理論について学びました。ADAMの強みは:
- パラメータごとに適応的な学習率を持つ
- モーメンタムと二次モーメントの利点を組み合わせている
- バイアス補正により初期ステップでも効率的に学習できる
- ハイパーパラメータの設定に敏感ではない