0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【Rust】Burnのソースコードから学ぶ「バッチ正規化」

Last updated at Posted at 2025-05-03

ディープラーニングの学習を安定させ高速化する重要な技術「バッチ正規化(Batch Normalization)」。この記事では、Rust言語で実装されたディープラーニングフレームワーク「Burn」のソースコードを通じて、バッチ正規化の仕組みを解説します。

はじめに

バッチ正規化は2015年にGoogleの研究者によって提案された手法で、ニューラルネットワークの中間層の出力を正規化することにより、学習の安定性と速度を向上させる技術です。この記事では、Burnのソースコードを読み解きながら、バッチ正規化について理解します。

バッチ正規化の基本概念

内部共変量シフト問題

ディープラーニングのモデルを学習させる際、各層の入力分布が学習中に変化する「内部共変量シフト(Internal Covariate Shift)」という問題が発生します。この問題により学習が不安定になったり、時間がかかったりすることがあります。

バッチ正規化はこの問題に対処するために考案された手法で、ニューラルネットワークの中間層の出力を適切に調整します。

バッチ正規化の効果

バッチ正規化には主に次のような効果があります:

  1. 学習の高速化 - より大きな学習率を使えるようになる
  2. 初期値への依存性の軽減 - 重みの初期化に敏感でなくなる
  3. 勾配消失問題の軽減 - 勾配が適切に伝播しやすくなる
  4. 正則化効果 - 過学習を抑制する効果がある
  5. より深いネットワークの学習を可能に

バッチ正規化の数学的定義

バッチ正規化は、次のような数学的処理を行います:

  1. ミニバッチごとに各特徴量の平均μと分散σ²を計算
  2. 入力値xを正規化: $(x - \mu) / \sqrt{\sigma^2 + \epsilon}$
  3. スケール係数γと平行移動係数βで調整: $\gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$

ここで:

  • μはミニバッチの平均
  • σ²はミニバッチの分散
  • εは数値的安定性のための小さな定数(ゼロ除算防止)
  • γとβは学習可能なパラメータ

γとβは学習可能なパラメータであり、各層が最適な出力分布を学習できるようにします(後述)。

Burnにおけるバッチ正規化の実装

それでは、Burnのソースコードを見ながら、バッチ正規化がどのように実装されているのか見ていきましょう。

設定と構造体の定義

最初に、バッチ正規化の設定を行うための構造体と、実際のバッチ正規化を行う構造体の定義を見てみましょう:

// バッチ正規化の設定を行うための構造体
#[derive(Config, Debug)]
pub struct BatchNormConfig {
    /// 特徴量の数(チャネル数)
    pub num_features: usize,
    /// 数値的安定性のための小さな定数。デフォルト: 1e-5
    #[config(default = 1e-5)]
    pub epsilon: f64,
    /// 移動平均を更新する際のモメンタム。デフォルト: 0.1
    #[config(default = 0.1)]
    pub momentum: f64,
}

// バッチ正規化を実行するモジュール
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BatchNorm<B: Backend, const D: usize> {
    /// 学習可能なスケール係数γ
    pub gamma: Param<Tensor<B, 1>>,
    /// 学習可能なシフト係数β
    pub beta: Param<Tensor<B, 1>>,
    /// 推論時に使用する移動平均
    pub running_mean: RunningState<Tensor<B, 1>>,
    /// 推論時に使用する移動分散
    pub running_var: RunningState<Tensor<B, 1>>,
    /// 移動平均を更新するモメンタム
    pub momentum: f64,
    /// 数値的安定性のための小さな定数
    pub epsilon: f64,
}

BatchNormConfig構造体では、バッチ正規化に必要な設定を定義しています:

  • num_features: 正規化する特徴量の数(チャネル数)
  • epsilon: 数値的安定性のための小さな定数(デフォルト値は1e-5)
  • momentum: 移動平均を更新する際のモメンタム(デフォルト値は0.1)

BatchNorm構造体は、実際のバッチ正規化の処理を行うモジュールです:

  • gamma: 学習可能なスケール係数(正規化後の分散を調整)
  • beta: 学習可能なシフト係数(正規化後の平均を調整)
  • running_mean: 推論時に使用する特徴量の移動平均
  • running_var: 推論時に使用する特徴量の分散の移動平均
  • momentum: 移動平均を更新するモメンタム
  • epsilon: 数値的安定性のための小さな定数

BatchNorm構造体はジェネリックに設計されています。B: Backendによって様々なバックエンド(CPU、GPU)で動作するよう抽象化され、定数ジェネリックパラメータDにより、1次元、2次元などの異なる次元のデータに対応できます

初期化メソッド

次に、BatchNormConfigからBatchNormを初期化するメソッドを見てみましょう:

impl BatchNormConfig {
    /// バッチ正規化モジュールを初期化する
    pub fn init<B: Backend, const D: usize>(&self, device: &B::Device) -> BatchNorm<B, D> {
        // γを1で初期化(スケールを変えない初期状態)
        let gamma = Initializer::Ones.init([self.num_features], device);
        // βを0で初期化(シフトなしの初期状態)
        let beta = Initializer::Zeros.init([self.num_features], device);

        // 平均は0、分散は1で初期化
        let running_mean = Tensor::zeros([self.num_features], device);
        let running_var = Tensor::ones([self.num_features], device);

        BatchNorm {
            gamma,
            beta,
            running_mean: RunningState::new(running_mean),
            running_var: RunningState::new(running_var),
            momentum: self.momentum,
            epsilon: self.epsilon,
        }
    }
}

このメソッドでは次のことを行っています:

  1. gammaを1で初期化 - 最初は正規化された値のスケールを変えない
  2. betaを0で初期化 - 最初は正規化された値をシフトしない
  3. running_meanを0で初期化 - 推論時の平均の初期値
  4. running_varを1で初期化 - 推論時の分散の初期値

これらの初期値により、バッチ正規化の初期段階では単純に入力を標準化するだけの動作をします。学習が進むと、gammabetaは最適な値に調整され、running_meanrunning_varは訓練データ全体の統計情報を反映するように更新されます。

順伝播(forward)メソッド

バッチ正規化の中心となる処理を行うforwardメソッドを見てみましょう:

impl<const D: usize, B: Backend> BatchNorm<B, D> {
    /// 入力テンソルに対して順伝播を実行する
    ///
    /// # 入出力の形状
    /// - 入力: [batch_size, channels, ...]
    /// - 出力: [batch_size, channels, ...]
    pub fn forward<const DI: usize>(&self, input: Tensor<B, DI>) -> Tensor<B, DI> {
        // 入力テンソルの次元チェック
        if D + 2 != DI {
            panic!(
                "BatchNorm{}D can only be applied on tensors of size {} with the following shape \
                 [batch_size, channels, ...], received {}D tensor",
                D,
                D + 2,
                DI
            );
        }

        // 訓練モードと推論モードで処理を分ける
        match B::ad_enabled() {
            true => self.forward_train(input),  // 訓練モード
            false => self.forward_inference(input),  // 推論モード
        }
    }
}

このメソッドでは、まず入力テンソルの次元を確認し、バッチ正規化の前提条件(形状が[batch_size, channels, ...])を満たしているかチェックします。その後、現在の状態(訓練または推論)に応じて異なる処理を呼び出します:

  • 自動微分が有効(B::ad_enabled()true)なら訓練モードで実行
  • そうでなければ推論モードで実行

このように同じモジュールで訓練時と推論時の両方に対応できるのが特徴です。

推論時の処理

推論時には、訓練中に計算された移動平均と移動分散を使用します:

fn forward_inference<const DI: usize>(&self, input: Tensor<B, DI>) -> Tensor<B, DI> {
    let device = input.device();
    let channels = input.dims()[1];
    
    // 保存されている移動平均と移動分散を取得
    let mean = self.running_mean.value().to_device(&device);
    let var = self.running_var.value().to_device(&device);

    // ブロードキャスト用に形状を調整
    let mut shape = [1; DI];
    shape[1] = channels;

    // 共通の正規化処理を呼び出す
    self.forward_shared(input, mean.reshape(shape), var.reshape(shape))
}

この処理では:

  1. 保存されている移動平均と移動分散を取得
  2. 入力テンソルと同じデバイスに転送(CPU/GPU間の移動がある場合)
  3. 入力テンソルの形状に合わせて形状を変形
  4. 共通の正規化処理を呼び出す

推論時には、現在のバッチだけでなく、訓練時に蓄積された統計情報を使うことで、バッチサイズに依存しない安定した出力を生成できます。これは特に小さなバッチサイズや、バッチサイズ1での推論で重要です。

訓練時の処理

訓練時には、現在のミニバッチの統計量を計算し、同時に移動平均を更新します:

fn forward_train<const DI: usize>(&self, input: Tensor<B, DI>) -> Tensor<B, DI> {
    let device = input.device();
    let dims = input.dims();
    let batch_size = dims[0];
    let channels = dims[1];

    // 統計計算用の形状準備
    let mut shape_unsqueeze = [1; DI];
    let mut flatten_size = batch_size;
    shape_unsqueeze[1] = channels;

    for dim in dims.iter().take(DI).skip(2) {
        flatten_size *= dim;
    }

    // 現在のバッチの平均を計算
    let mean = input
        .clone()
        .swap_dims(0, 1)
        .reshape([channels, flatten_size])
        .mean_dim(1)
        .reshape(shape_unsqueeze);

    // 現在のバッチの分散を計算
    let var = input
        .clone()
        .sub(mean.clone())
        .powf_scalar(2.0)
        .swap_dims(0, 1)
        .reshape([channels, flatten_size])
        .mean_dim(1)
        .reshape(shape_unsqueeze);

    // 移動平均・移動分散を取得
    let running_mean = self.running_mean.value_sync().to_device(&device);
    let running_var = self.running_var.value_sync().to_device(&device);

    // 移動平均を更新
    let running_mean = running_mean.mul_scalar(1.0 - self.momentum).add(
        mean.clone()
            .detach()
            .mul_scalar(self.momentum)
            .reshape([channels]),
    );
    
    // 移動分散を更新
    let running_var = running_var.mul_scalar(1.0 - self.momentum).add(
        var.clone()
            .detach()
            .mul_scalar(self.momentum)
            .reshape([channels]),
    );

    // 更新した値を保存
    self.running_mean.update(running_mean.detach());
    self.running_var.update(running_var.detach());

    // 共通の正規化処理を呼び出す
    self.forward_shared(input, mean, var)
}

この処理は次のステップで行われます:

  1. 現在のミニバッチから平均を計算

    • 入力テンソルのバッチ次元とチャネル次元を入れ替え
    • チャネルごとに他の次元をフラット化
    • 各チャネルの平均を計算
  2. 現在のミニバッチから分散を計算

    • 入力から平均を引いて差を計算
    • 差を2乗
    • チャネルごとに平均を取る
  3. 移動平均と移動分散の更新

    • 既存の値に(1 - momentum)を掛け、新しい値にmomentumを掛けて加算
    • これにより新しい値を少しずつ反映しつつ、過去の値も考慮した平滑化が行われる
    • 勾配計算から切り離す(detach)(移動平均は学習対象ではない)
  4. 共通の正規化処理を呼び出す

チャネルごとに統計量を計算しています。これはCNN(畳み込みニューラルネットワーク)のバッチ正規化でよく使われるアプローチで、各チャネルが独立した特徴を表すという考えに基づいています。

forward_train 内で移動平均と移動分散を計算・更新する際に .detach() メソッドを使用することで、これらの統計情報の更新がモデルパラメータの勾配計算に影響を与えないようにしています。

共通の正規化処理

最後に、訓練時と推論時の両方で使用される共通の正規化処理を見てみましょう:

fn forward_shared<const DI: usize>(
    &self,
    x: Tensor<B, DI>,
    mean: Tensor<B, DI>,
    var: Tensor<B, DI>,
) -> Tensor<B, DI> {
    let channels = x.dims()[1];
    let mut shape = [1; DI];
    shape[1] = channels;

    // 標準偏差を計算(分散の平方根)
    let std = var.add_scalar(self.epsilon).sqrt();

    // 平均を引いて標準偏差で割る(標準化)
    let x = x.sub(mean);
    let x = x.div(std);

    // γをかける(スケーリング)
    let x = x.mul(self.gamma.val().reshape(shape));

    // βを足す(シフト)
    x.add(self.beta.val().reshape(shape))
}

この処理では、バッチ正規化の基本的な数式 $y = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$ を実装しています:

  1. 標準偏差の計算:分散にepsilonを加えてから平方根

    • epsilonはゼロ除算を防ぐ役割もある
  2. 平均を引いて標準偏差で割る:$(x - \mu) / \sqrt{\sigma^2 + \epsilon}$

    • これにより各特徴の分布が平均0、分散1に正規化される
  3. 学習可能なパラメータγで乗算:$\gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}$

    • γはスケーリング係数で、適切な分散を学習する
  4. 学習可能なパラメータβで加算:$\gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$

    • βはシフト係数で、適切な平均を学習する

γとβのパラメータは、単なる標準化では失われる可能性のある表現力を回復させる重要な役割を持っています。例えば、活性化関数が特定の範囲の入力で最もよく機能する場合、モデルは最適な出力分布になるようγとβを調整できます。

バッチ正規化の訓練と推論の違い

バッチ正規化は訓練時と推論時で異なる振る舞いをします。この違いをまとめます。

訓練時の振る舞い

訓練時には:

  1. 現在のミニバッチの統計量(平均と分散)を使用して正規化
  2. 各バッチごとに統計量が微妙に異なることで正則化効果が生まれる
  3. 同時に、将来の推論用に移動平均と移動分散を更新
  4. 勾配を計算してγとβを更新

推論時の振る舞い

推論時には:

  1. 訓練時に計算された移動平均と移動分散を使用
  2. バッチに依存せず一貫した出力を生成
  3. 小さなバッチやサンプルごとの推論でも安定した結果を提供

この違いにより、バッチ正規化は訓練時には正則化効果を発揮し、推論時には安定した予測をするようになります。Burnは自動微分の状態を確認して自動的に適切なモードを選択できるのがフレームワークとしての特徴です。

実際の使用例

実際のニューラルネットワークでBurnのバッチ正規化を使用する例を見てみましょう。

例えば、CNNで使用する場合:

use burn::module::Module;
use burn::nn::{BatchNormConfig, Conv2dConfig, ReLU};
use burn::tensor::backend::Backend;

// CNNのブロック(畳み込み+バッチ正規化+活性化関数)
pub struct ConvBlock<B: Backend> {
    conv: Conv2d<B>,         // 畳み込み層
    batch_norm: BatchNorm<B, 2>,  // 2次元のバッチ正規化
    activation: ReLU,        // 活性化関数
}

impl<B: Backend> ConvBlock<B> {
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        device: &B::Device,
    ) -> Self {
        // 畳み込み層の設定
        let conv_config = Conv2dConfig::new(in_channels, out_channels)
            .with_kernel_size(3)
            .with_padding(1);
        
        // バッチ正規化の設定
        let bn_config = BatchNormConfig::new(out_channels);
        
        Self {
            conv: conv_config.init(device),
            batch_norm: bn_config.init(device),
            activation: ReLU::new(),
        }
    }
    
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let x = self.conv.forward(x);
        let x = self.batch_norm.forward(x);
        self.activation.forward(x)
    }
}

このように、畳み込み層の後にバッチ正規化を適用し、その後に活性化関数を適用するパターンは一般的です。この順序には重要な理由があります:

  1. 畳み込み層(線形変換)は出力が広い範囲に分布する可能性がある
  2. バッチ正規化で出力を適切な範囲に調整する
  3. 正規化された値に対して活性化関数を適用することで効率的な学習が可能になる

特に深いネットワークでは、「畳み込み→バッチ正規化→活性化」のパターンが各層で繰り返されることで、勾配消失問題を軽減し、学習を安定させます。

なぜバッチ内の統計量で性能が改善するのか? ― 依存性と安定化のトレードオフ

バッチ正規化の中心的なアイデアは、ミニバッチ内のサンプル全体の統計量(平均と分散)を使って各サンプルを正規化することです。ここで素朴な疑問が生じます。「ニューラルネットワークの学習では、各サンプルは独立に処理されるべきではないのか?」「なぜ、あるサンプルの処理が、同じバッチ内の他のサンプルに依存することで性能が改善するのか?

この疑問はもっともですが、バッチ正規化の有効性は、まさにこの「ミニバッチ内の依存性」を巧みに利用している点にあります。

1. 内部共変量シフトへの現実的な対策

前述の通り、ディープラーニングの学習における課題の一つが「内部共変量シフト」です。つまり、学習が進むにつれて前の層のパラメータが変化し、後続の層への入力分布が絶えず変動してしまう問題です。理想的には、各層への入力分布が学習中に安定していることが望ましいです。

データセット全体の真の平均と分散を計算して正規化できれば理想的ですが、これは通常、計算コストが高すぎて現実的ではありません(特にデータセットが大きい場合)。

そこでバッチ正規化では、現在処理しているミニバッチの統計量を、データセット全体の統計量の「推定量」として利用します。ミニバッチはデータセットからランダムにサンプリングされているため、バッチサイズがある程度大きければ、その統計量は真の統計量をそれなりに良く近似すると期待できます。

このミニバッチ統計量を使って強制的に入力分布を(近似的に)平均0、分散1に近づけることで、層への入力分布の変動を大幅に抑制します。結果として、後続の層はより安定した入力に対して学習を進めることができ、学習が効率化・安定化します。

2. 「依存性」がもたらす副次的効果

ミニバッチの統計量を使うこと、つまりバッチ内のサンプル間で依存性を持たせることには、さらにいくつかの利点があります。

  • 勾配の流れの促進: 正規化によって値が適切な範囲に収まるため、活性化関数(特にシグモイドやtanhなど)が飽和しにくくなり、勾配消失問題が軽減されます。これにより、深いネットワークでも学習が進みやすくなります。

  • 意図せぬ正則化: 各ミニバッチの統計量は、真の統計量に対してわずかな「ノイズ」を含みます。同じ入力データでも、どのバッチに含まれるかによって正規化に使われる平均・分散が微妙に異なります。このランダムな変動が、モデルが訓練データに過剰適合するのを防ぐ一種の正則化効果を生み出します。これは、各サンプルを完全に独立に扱う場合には得られない効果です。

3. 訓練時と推論時の使い分け

重要なのは、この「バッチ依存性」は訓練時にのみ意図的に導入されるものであるという点です。訓練時には、内部共変量シフトの抑制と正則化効果を狙ってミニバッチ統計量を使用します。

一方、推論(予測)時には、予測結果がどのバッチに含まれるかに依存しては困ります。そのため、推論時には訓練中に計算・蓄積しておいた移動平均と移動分散(データセット全体の統計量のより安定した推定量)を使って正規化を行います。これにより、推論時には各サンプルは独立に処理され、入力に対して決定的な出力が得られます。

バッチへの依存性は単なるノイズではない

結論として、バッチ正規化がバッチ内の統計量を利用するのは、内部共変量シフトという現実的な問題に対する効率的な近似解法を提供するためです。その過程で生じる「サンプル間の依存性」は、学習の安定化や正則化といった有益な副次効果をもたらします。訓練時と推論時で挙動を変えることで、この依存性の利点を享受しつつ、推論時の一貫性を担保しているのです。バッチ内の依存性は、バッチ正規化という手法がうまく機能するための重要な 「特徴」 と言えるでしょう。

バッチ正規化の利点と注意点

利点

  1. 学習の高速化: より高い学習率を使用できるため、収束が早くなります
  2. 初期化依存性の低減: 重みの初期値の選択に対する敏感さが減ります
  3. 正則化効果: ドロップアウトと同様、過学習を抑制する効果があります
  4. 深いネットワークの学習安定化: 勾配消失問題を軽減します
  5. コスト削減: 適切なバッチ正規化により、より少ないパラメータ数で良好な性能を達成できる場合があります

注意点

  1. バッチサイズの依存性: 小さすぎるバッチサイズでは統計量の推定が不安定になる可能性があります(一般的に16以上推奨)
  2. 計算コスト: 追加の計算とメモリが必要になります
  3. ドロップアウトとの組み合わせ: 一般的にはバッチ正規化の後にドロップアウトを適用するのが良いとされています
  4. 再現性の問題: バッチに依存するため、同じ入力でもバッチサイズや構成によって出力が変わる可能性があります
  5. RNNでの使用: リカレントネットワークでは、シーケンス長が異なる場合に適用が難しいことがあります(LayerNormalizationなどの代替手法がよく使用されます)

まとめ

本記事では、Rustのディープラーニングフレームワーク「Burn」のソースコードを通じて、バッチ正規化の仕組みとその実装方法を解説しました。バッチ正規化は以下の特徴を持つ重要な技術です:

  1. ニューラルネットワークの学習を安定化・高速化する
  2. 訓練時と推論時で異なる挙動をする
  3. 学習可能なパラメータ(γとβ)を持ち、モデルの表現力を向上させる
  4. 内部共変量シフトを軽減し、勾配の流れを改善する

Burnの実装では、Rustの型システムと所有権モデルを活かし、安全かつ効率的にバッチ正規化を実現しています。ジェネリックプログラミングを活用して様々な次元のテンソルに対応できる点や、訓練時と推論時の挙動の違いを明確に分離しています。

参考文献

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?