0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【Rust】Burnのソースコードで学ぶADAM最適化アルゴリズム

Posted at

機械学習の世界で広く使われている最適化アルゴリズムの一つにADAM(Adaptive Moment Estimation)があります。この記事では、Rustで実装された機械学習フレームワーク「Burn」のソースコードを読み解きながら、ADAMの仕組みと実装方法について学んでいきましょう。

はじめに

ADAMは2014年に発表された最適化アルゴリズムで、その効率的な学習能力から現代のディープラーニングでは標準的に使用されています。特に、学習率の自動調整機能を持ち、パラメータごとに適応的な更新を行える点が大きな特徴です。

この記事を読むことで以下のことが理解できるようになります:

  • ADAMアルゴリズムの基本的な仕組み
  • Rustにおける実装方法と設計パターン
  • 実際のフレームワークでどのように使われているか

前提知識

この記事を理解するために以下の知識があると役立ちます:

  • 機械学習の基礎概念(勾配降下法、損失関数など)
  • Rustの基本的な文法
  • オプティマイザの役割の理解

ADAMアルゴリズムの数学的理解

ADAMの動作原理を数式を使って理解しましょう。機械学習では、損失関数 $L(\theta)$ を最小化するパラメータ $\theta$ を見つけることが目標です。

標準的な勾配降下法

最も基本的な勾配降下法では、パラメータを次のように更新します:

$\theta_{t+1} = \theta_t - \alpha \cdot \nabla L(\theta_t)$

ここで:

  • $\theta_t$ は時刻 $t$ でのパラメータ値
  • $\alpha$ は学習率
  • $\nabla L(\theta_t)$ は損失関数の勾配

この方法では、全パラメータに対して同じ学習率を適用するため、非効率的な場合があります。

ADAMの改良点

ADAMは次の2つの主要な改良を導入しています:

  1. モーメンタム: 過去の勾配情報を活用して更新方向を安定化
  2. 適応的学習率: パラメータごとに最適な学習率を自動調整

ADAMの更新則

ADAMは以下のステップでパラメータを更新します:

ステップ1: モーメントの計算

$$
m_t = β₁ · m_{t-1} + (1 - β₁) · g_t
$$

$$
v_t = β₂ · v_{t-1} + (1 - β₂) · g_t²
$$

ここで:

  • $g_t$ は現在の勾配 $\nabla L(\theta_t)$
  • $m_t$ は勾配の指数移動平均(一次モーメント)
  • $v_t$ は勾配の二乗の指数移動平均(二次モーメント)
  • $\beta_1$ と $\beta_2$ はハイパーパラメータ(一般的に $\beta_1 = 0.9$, $\beta_2 = 0.999$)

この式は「過去の情報を $\beta$ の割合で保持し、新しい情報を $1-\beta$ の割合で取り入れる」という意味です。

ステップ2: バイアス補正

学習の初期段階では、$m_t$と$v_t$は0で初期化されるため、値が小さく偏ってしまいます。これを補正します:

$$
m̂_t = m_t / (1 - β₁ᵗ)
$$

$$
v̂_t = v_t / (1 - β₂ᵗ)
$$

$t$ が大きくなるにつれて補正効果は小さくなるのが特徴です。

ステップ3: パラメータ更新

$$
θ_{t+1} = θ_t - α · m̂_t / (√v̂_t + ε)
$$

ここで $\epsilon$ は数値安定性のための小さな値(通常 $10^{-8}$ 程度)です。

ADAMオプティマイザの設定構造体

まずはBurnにおけるADAMの設定部分から見ていきましょう:

#[derive(Config)]
pub struct AdamConfig {
    /// Parameter for Adam.
    #[config(default = 0.9)]
    beta_1: f32,
    /// Parameter for Adam.
    #[config(default = 0.999)]
    beta_2: f32,
    /// A value required for numerical stability.
    #[config(default = 1e-5)]
    epsilon: f32,
    /// [Weight decay](WeightDecayConfig) config.
    weight_decay: Option<WeightDecayConfig>,
    /// [Gradient Clipping](GradientClippingConfig) config.
    grad_clipping: Option<GradientClippingConfig>,
}

このAdamConfig構造体からわかることは:

  1. beta_1とbeta_2: ADAMアルゴリズムの中心的なハイパーパラメータで、それぞれ一次と二次のモーメント更新の係数です。デフォルト値は論文で推奨されている値になっています。

  2. epsilon: 数値の安定性を保つための小さな値です。ゼロ除算を防ぐために使われます。

  3. weight_decay: 重み減衰の設定をオプションで指定できます。過学習を防ぐのに役立ちます。

  4. grad_clipping: 勾配クリッピングの設定もオプションで指定可能です。勾配爆発を防ぐために使われます。

Adamの構造体と状態

続いて、実際のAdam実装の構造体を見てみましょう:

#[derive(Clone)]
pub struct Adam {
    momentum: AdaptiveMomentum,
    weight_decay: Option<WeightDecay>,
}

#[derive(Record, Clone, new)]
pub struct AdamState<B: Backend, const D: usize> {
    /// The current adaptive momentum.
    pub momentum: AdaptiveMomentumState<B, D>,
}

Adamは「状態を持つ」オプティマイザです。AdamStateには現在のモーメンタム状態が保持され、反復ごとに更新されていきます。

ADAM更新ステップの実装

stepメソッドをみましょう:

fn step<const D: usize>(
    &self,
    lr: LearningRate,
    tensor: Tensor<B, D>,
    mut grad: Tensor<B, D>,
    state: Option<Self::State<D>>,
) -> (Tensor<B, D>, Option<Self::State<D>>) {
    let mut state_momentum = None;

    if let Some(state) = state {
        state_momentum = Some(state.momentum);
    }

    if let Some(weight_decay) = &self.weight_decay {
        grad = weight_decay.transform(grad, tensor.clone());
    }

    let (grad, state_momentum) = self.momentum.transform(grad, state_momentum);

    let state = AdamState::new(state_momentum);
    let delta = grad.mul_scalar(lr);

    (tensor - delta, Some(state))
}
  1. まず既存の状態があれば取り出します
  2. 重み減衰が設定されていれば勾配を変換します
  3. momentum.transformを呼び出してADAMの中核となる適応的モーメント推定を行います
  4. 新しい状態を生成し、学習率を掛けた勾配で重みを更新します

適応的モーメントの実装

重要な適応的モーメント推定のロジックはAdaptiveMomentum構造体のtransformメソッドに実装されています:

pub fn transform<B: Backend, const D: usize>(
    &self,
    grad: Tensor<B, D>,
    momentum_state: Option<AdaptiveMomentumState<B, D>>,
) -> (Tensor<B, D>, AdaptiveMomentumState<B, D>) {
    let state = if let Some(mut state) = momentum_state {
        let factor = 1.0 - self.beta_1;
        state.moment_1 = state
            .moment_1
            .mul_scalar(self.beta_1)
            .add(grad.clone().mul_scalar(factor));

        let factor = 1.0 - self.beta_2;
        state.moment_2 = state
            .moment_2
            .mul_scalar(self.beta_2)
            .add(grad.powf_scalar(2.0).mul_scalar(factor));

        state.time += 1;

        state
    } else {
        let factor = 1.0 - self.beta_1;
        let moment_1 = grad.clone().mul_scalar(factor);

        let factor = 1.0 - self.beta_2;
        let moment_2 = grad.powf_scalar(2.0).mul_scalar(factor);

        AdaptiveMomentumState::new(1, moment_1, moment_2)
    };

    let time = (state.time as i32).elem();
    let moment_1_corrected = state
        .moment_1
        .clone()
        .div_scalar(1f32 - self.beta_1.powi(time));
    let moment_2_corrected = state
        .moment_2
        .clone()
        .div_scalar(1f32 - self.beta_2.powi(time));

    let grad = moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon));

    (grad, state)
}

ここでADAMアルゴリズムが実装されています:

  1. 一次モーメント(平均)の更新

    state.moment_1 = state.moment_1.mul_scalar(self.beta_1).add(grad.clone().mul_scalar(factor));
    

    これは勾配の指数移動平均を計算しています。

  2. 二次モーメント(分散)の更新

    state.moment_2 = state.moment_2.mul_scalar(self.beta_2).add(grad.powf_scalar(2.0).mul_scalar(factor));
    

    これは勾配の二乗の指数移動平均を計算しています。

  3. バイアス補正

    let moment_1_corrected = state.moment_1.clone().div_scalar(1f32 - self.beta_1.powi(time));
    let moment_2_corrected = state.moment_2.clone().div_scalar(1f32 - self.beta_2.powi(time));
    

    初期ステップでのモーメント推定のバイアスを補正します。

  4. 適応的学習率の計算

    let grad = moment_1_corrected.div(moment_2_corrected.sqrt().add_scalar(self.epsilon));
    

    一次モーメントを二次モーメントの平方根で割ることで、パラメータごとに適応した更新量を得ます。

なぜこの割り算が重要なのか?

1. 勾配の大きさの正規化

パラメータ更新を考える時、単に勾配(またはその移動平均である一次モーメント)をそのまま使うと問題があります。勾配の大きさはパラメータによって大きく異なる可能性があるからです。

例えば:

  • あるパラメータの勾配が常に大きい(例:10.0前後)
  • 別のパラメータの勾配が常に小さい(例:0.01前後)

この場合、同じ学習率を使うと、勾配が大きいパラメータは過剰に更新され、小さいパラメータは十分に更新されません。

2. 二次モーメントの意味

$v_t$ (二次モーメント) は勾配の二乗の指数移動平均です。これは、「このパラメータの勾配はどれくらい変動しているか」という情報を持っています:

  • $v_t$ が大きい → そのパラメータの勾配は大きく、またはよく変動している
  • $v_t$ が小さい → そのパラメータの勾配は小さく、比較的安定している

3. 割り算の効果

$\hat{m}_t$ を $\sqrt{\hat{v}_t}$ で割ることで何が起きるか考えましょう:

  • 勾配変動が大きいパラメータ($v_t$ が大きい)では、$\hat{m}_t / \sqrt{\hat{v}_t}$ は小さくなります → 更新が慎重になる

  • 勾配変動が安定しているパラメータ($v_t$ が小さい)では、$\hat{m}_t / \sqrt{\hat{v}_t}$​​ は大きくなります → 更新が積極的になる

ADAMアルゴリズム実装を理解する

  1. 一次モーメント $m_t$(方向)

    • 過去の勾配の方向を記憶することで、ノイズの影響を減らし、より安定した方向に進みます
    • 局所的な変動に惑わされにくくなります
  2. 二次モーメント $v_t$(スケール)

    • 勾配の二乗値の履歴を追跡し、パラメータごとの更新スケールを調整します
    • 頻繁に大きな勾配を持つパラメータ($v_t$ が大きい)は更新量が小さくなり、安定したパラメータは大きく更新されます
    • これは「道の状況に応じて速度を調整する」ようなもので、急な斜面では慎重に、緩やかな斜面では大胆に進みます
  3. バイアス補正

    • 学習初期では、$m_t$ と $v_t$ は0に初期化されるため、過小評価される傾向があります
    • 補正係数 $(1-\beta^t)$ で割ることで、この初期バイアスを解消します
    • 特に学習の初期段階で重要な役割を果たします
  4. 適応的更新量 $\alpha \cdot \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)$

    • 各パラメータに最適な更新量を自動的に決定します
    • 二次モーメントが大きいパラメータは小さく更新され、小さいパラメータは大きく更新されます

ADAMの具体例

パラメータ1つの簡単な例で考えてみます。3ステップの勾配が $g_1 = 4.0$, $g_2 = 2.0$, $g_3 = 1.0$ だったとします。
$\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$, $\alpha = 0.001$ として計算すると:

ステップ1:

  • $m_1 = 0.4$, $v_1 = 0.016$
  • バイアス補正後: $\hat{m}_1 = 4.0$, $\hat{v}_1 = 16.0$
  • 更新量: $0.00100$

ステップ2:

  • $m_2 = 0.56$, $v_2 = 0.02$
  • バイアス補正後: $\hat{m}_2 = 2.95$, $\hat{v}_2 = 10.0$
  • 更新量: $0.00093$

ステップ3:

  • $m_3 = 0.604$, $v_3 = 0.021$
  • バイアス補正後: $\hat{m}_3 = 2.12$, $\hat{v}_3 = 7.0$
  • 更新量: $0.00080$

勾配が小さくなるにつれて、ADAMは自動的に更新量を調整します。これが、ADAMが効果的な理由です:

  1. パラメータごとの適応的学習率:各パラメータの勾配の履歴に基づいて学習率を調整します。

  2. モーメンタムと二次モーメントの組み合わせ:一次モーメントによる慣性と二次モーメントによる適応的なスケーリングを兼ね備えています。

  3. バイアス補正:初期ステップでの推定値のバイアスを補正することで、学習の初期段階での性能を向上させます。

使用例

BurnフレームワークでADAMを使う際は、以下のようにシンプルに設定できます:

let optimizer = AdamConfig::new()
    .with_beta_1(0.9)
    .with_beta_2(0.999)
    .with_epsilon(1e-8)
    .with_weight_decay(Some(WeightDecayConfig::new(0.01)))
    .init();

これで、モデルのトレーニングにADAMオプティマイザを使用する準備が整います。

よくある質問

Q: SGDやモーメンタムと比べて、ADAMはどのような場合に有利ですか?

A: ADAMは以下のような場合に特に効果的です:

  • 大規模なデータセットや複雑なモデルの学習
  • スパースな勾配が発生するようなケース(自然言語処理など)
  • ハイパーパラメータのチューニングに時間をかけられない場合
  • 非定常的な目的関数(時間とともに変化する問題)の最適化

一方、単純な問題やデータセットが小さい場合は、SGDやモーメンタムの方が良いパフォーマンスを示すこともあります。

Q: ADAMのハイパーパラメータはどのように設定すればよいですか?

A: ADAMは比較的ハイパーパラメータの設定に robust なのが特徴ですが、一般的には:

  • 学習率 $\alpha$: 0.001(ネットワークの構造によって調整が必要)
  • $\beta_1$: 0.9(一次モーメントの減衰率)
  • $\beta_2$: 0.999(二次モーメントの減衰率)
  • $\epsilon$: $10^{-8}$(数値安定性のための定数)

これらは論文で推奨されている値で、多くの場合そのまま使えますが、特に学習率は問題に応じて調整する価値があります。

Q: ADAMの欠点や注意点はありますか?

A: ADAMにもいくつかの課題があります:

  • 一部の問題では汎化性能がSGDより劣ることがある
  • メモリ使用量が大きい(各パラメータに対して2つの状態を保持)
  • 学習の後半でSGDよりも精度が出ない場合がある
  • 一部の問題では収束が保証されない場合がある

こうした課題に対処するために、AMSGrad, AdamWなどの派生アルゴリズムも提案されています。

まとめ

この記事では、RustのBurnフレームワークのソースコードを通して、ADAM最適化アルゴリズムの実装と理論について学びました。ADAMの強みは:

  • パラメータごとに適応的な学習率を持つ
  • モーメンタムと二次モーメントの利点を組み合わせている
  • バイアス補正により初期ステップでも効率的に学習できる
  • ハイパーパラメータの設定に敏感ではない
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?