0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【Rust】Burnのソースから学ぶ「自動微分」

Last updated at Posted at 2025-05-03

なぜ「自動微分」が必要なのか?

AIの学習とは「調整」の繰り返し

AIが学習するとは、簡単に言えば「設定値(パラメータ)を少しずつ調整して、正解に近づけていく作業」です。例えば、猫の画像を判別するAIを考えてみましょう:

  • 「この線が斜めだと猫っぽいから、この設定値を少し増やそう」
  • 「この色合いは犬っぽいから、あの設定値は減らそう」

このような微調整を何百万回も繰り返すことで、AIは徐々に賢くなっていきます。

「どの方向に」「どれくらい」調整するかが重要

問題は、「どの設定値を」「どの方向に」「どれくらい」調整すれば効率良く学習できるのか、という点です。この「方向と量のヒント」のことを 「勾配(こうばい)」 と呼びます。

坂道をイメージすると分かりやすいでしょう。勾配が分かれば、最も効率よく坂を下る方向(=正解に近づく方向)が分かります。この坂を下る方法を「勾配降下法」と呼び、これがディープラーニングの核心です。

手計算では不可能な量の計算

現代のAIモデルは何百万、何千万もの設定値(パラメータ)を持っています。これらすべての勾配を人間が手計算するのは不可能です。そこで登場するのが「自動微分」という技術です。コンピュータが自動的に、しかも効率的に勾配を計算してくれます。

自動微分の仕組みを理解する

料理のレシピで例えると...

自動微分の仕組みを料理に例えて考えてみましょう:

  1. 計算の手順を記録(フォワードパス):
    カレーを作るとき、「玉ねぎを切る」「肉を炒める」「水を入れる」「ルーを溶かす」といった手順を細かくメモしながら料理するとします。各工程と、その時点での鍋の中身(=計算途中の値)をすべて記録します。

  2. 勾配を計算(リバースパス):
    できあがったカレーを食べて「もっと辛くしたい」と思ったとします(=AIの予測を改善したい)。自動微分は、記録したレシピを逆からたどって、「辛さを増すには、どの工程で入れたスパイスをどれくらい増やせばいいか」を教えてくれます。

実際の計算例でもう少し具体的に

数式で見ると:y = x1 * x2 + sin(x1) という計算があるとします。x1=2x2=3 のとき、x1x2が少し変わったら最終結果yがどれくらい変わるか(=勾配)を知りたいケースです。

  1. フォワードパス(計算と記録):

    • v1 = x1 (= 2)
    • v2 = x2 (= 3)
    • v3 = v1 * v2 (= 6)
    • v4 = sin(v1) (≈ 0.909)
    • y = v3 + v4 (≈ 6.909)

    この計算の流れと途中の値をすべて記録します。

  2. リバースパス(勾配計算):
    最終結果yから逆向きに、各計算がどれくらい結果に影響したかを計算していきます。最終的に「x1が少し変わると結果は2.584倍ほど変わる」「x2が少し変わると結果は2倍変わる」という勾配情報が得られます。

これが自動微分の基本的な考え方です。複雑な計算でも、基本的な足し算や掛け算などの勾配さえ分かっていれば、連鎖律(複合関数の微分法則)を使って全体の勾配を自動で計算できるのです。

BurnはRustでどのように自動微分を実現しているのか

2種類の計算エンジン

Burnフレームワークでは、計算を実行する部品(バックエンド)が2種類あります:

  1. Backend: 基本的な計算(足し算、掛け算など)を行うエンジン
  2. AutodiffBackend: 上記の機能に加えて、自動微分(勾配計算)の機能も提供するエンジン

勾配が必要な学習時などには特別なAutodiffBackendを使う必要があります。

デコレータ:既存のエンジンに自動微分機能を「着せる」

Burnのスマートな点は、「デコレータ」という仕組みで既存の計算エンジンに後から自動微分機能を追加できることです:

// 例:通常のGPUバックエンド
type MyBackend = Wgpu<f32, i32>;

// 自動微分機能を追加する
type MyAutodiffBackend = Autodiff<MyBackend>;

Autodiff<...>でラップするだけで、通常のエンジンが自動微分に対応するようになります。

計算グラフの記録と勾配の計算

AutodiffBackendを使って計算を行うと、Burnは内部的に「どの値がどの計算で作られたか」という情報を記録します。そして勾配が必要になったとき、.backward()メソッドを呼び出すことで勾配計算が実行されます:

// loss は最終的な損失値(これに対する勾配を計算したい)
let gradients: Gradients = loss.backward(); // リバースパスを実行!

この.backward()が呼ばれると、先ほどの料理の例のように計算手順を逆向きにたどり、勾配が計算されるのです。

勾配の管理方法

Burnでは、計算された勾配はGradientsという専用の入れ物(コンテナ)に格納されます。勾配の保存と取得を担う重要な部分を見てみましょう:

/// Gradients container used during the backward pass.
pub struct Gradients {
    container: TensorContainer<GradID>,
}

impl Gradients {
    /// Removes a grad tensor from the container.
    pub fn remove<B: Backend>(&mut self, tensor: &AutodiffTensor<B>) -> Option<FloatTensor<B>> {
        self.container
            .remove::<B>(&tensor.node.id.value)
            .map(|tensor| tensor.tensor())
    }

    /// Gets a grad tensor from the container.
    pub fn get<B: Backend>(&self, tensor: &AutodiffTensor<B>) -> Option<FloatTensor<B>> {
        self.container
            .get::<B>(&tensor.node.id.value)
            .map(|tensor| tensor.tensor())
    }

    /// Register a grad tensor in the container.
    ///
    /// If the tensor already exists, add both tensors together before saving the result.
    pub fn register<B: Backend>(&mut self, node_id: NodeID, value: FloatTensor<B>) {
        if let Some(tensor_old) = self.container.remove::<B>(&node_id.value) {
            self.container.register::<B>(
                node_id.value,
                burn_tensor::TensorPrimitive::Float(B::float_add(value, tensor_old.tensor())),
            );
        } else {
            self.container
                .register::<B>(node_id.value, burn_tensor::TensorPrimitive::Float(value));
        }
    }
}

この実装からわかるように、Gradients構造体は内部にTensorContainerを持ち、各ノードID(GradID型)と対応する勾配テンソルを関連付けています。勾配を追加する際には、同じノードIDの勾配が既に存在する場合は加算されるという仕組みになっています。

使用する際は下記のようになります:

// backward()で勾配コンテナを取得
let mut gradients: Gradients = loss.backward();

// 特定のテンソルの勾配を取得
let grad_of_tensor = tensor.grad(&gradients);

// または勾配を取り出しつつコンテナから削除(効率的)
let grad_of_tensor = tensor.grad_remove(&mut gradients);

このコンテナが、計算に関わった各値(特に学習対象のパラメータ)に対応する勾配をまとめて管理しています。テンソルクラスにはメソッドとしてdetach()(計算グラフから切り離す)、require_grad()(勾配計算の対象にする)、is_require_grad()(勾配計算対象か確認)などが実装されており、自動微分の制御が可能になっています。

Burnの最近のリリースでは、メモリ管理の方法が大幅に改善され、テンソルのメモリ再利用が効率化されました。特にCPUでの実行時にはPyTorchと比較してメモリ使用量が少なくなっています。

PythonのPyTorchとの違い

APIの違い

PythonのPyTorchとBurnでは、メソッド名などが少し異なりますが、基本的な操作は似ています:

Burn API PyTorch 相当
tensor.detach() tensor.detach()
tensor.require_grad() tensor.requires_grad()
tensor.is_require_grad() tensor.requires_grad
tensor.set_require_grad(require_grad) tensor.requires_grad(False)

最大の違い:勾配の扱い方

これが最も大きな違いです:

  • PyTorch: .backward()を呼ぶと、各パラメータの.grad属性に直接勾配が書き込まれる
  • Burn: .backward()を呼ぶと、勾配が入ったGradientsコンテナが返り値として得られる

この違いは、Rustの「所有権システム」と関係しています。Burnの公式ドキュメントには、以下のように説明されています:

Burnの勾配処理方法はPyTorchとは異なります。backward()を呼び出しても各パラメータのgrad属性が更新されるのではなく、計算されたすべての勾配がコンテナに格納されて返されます。このアプローチには、勾配を他のスレッドに簡単に送信できるなど、多くの利点があります。

Burnの方式では:

  1. 勾配がどこにあるかが明確で、使い忘れを防ぎやすい
  2. 並列処理との相性が良く、異なるGPU間でのデータ移動が安全
  3. 勾配の加工や複数の勾配をまとめる操作がしやすい

Rustの所有権モデルを活かした設計により、Burnは内部的にロックの数を減らし、性能を向上させています。特に小さなモデルでの処理速度が向上しているとのことです。

推論時(学習以外)の違い

学習以外のとき(テストや実運用時)には勾配計算は不要で、むしろ無駄な計算になります:

  • PyTorch: with torch.no_grad(): というブロックで囲んで一時的に勾配計算をオフにする
  • Burn: 最初から自動微分機能がない通常のBackendを使えば良い

Burnのソースコードを見ると、この区別が型レベルで明確に表現されています:

// 学習時など:AutodiffBackend を使う
fn train_step<B: AutodiffBackend>(model: &Model<B>, input: Tensor<B, N>) { ... }

// 推論時:通常の Backend を使う
fn inference<B: Backend>(model: &Model<B>, input: Tensor<B, N>) { ... }

また、Burnの公式ドキュメントには、自動微分を使わない場合でもinner()メソッドで内部テンソルを取得する方法が紹介されています:

fn example_validation<B: AutodiffBackend>(tensor: Tensor<B, 2>) {
    // 内部の通常テンソルを取り出す(勾配計算対象外)
    let inner_tensor: Tensor<B::InnerBackend, 2> = tensor.inner();
    let _ = inner_tensor + 5;
}

Burnのアプローチは、Rustの型システムによってコンパイル時に自動微分を使うかどうかが決まるため、実行時にモードを切り替えるPyTorchよりも安全という利点があります。

自動微分を使った実際の学習例

モデル学習の基本的な流れ

AIモデルの一般的な学習ループは次のようになります:

// 1. モデルで予測を計算(フォワードパス、計算グラフが記録される)
let output = model.forward(input);

// 2. 予測と正解から損失(誤差)を計算
let loss = calculate_loss(output, target);

// 3. 損失に対する勾配を計算(リバースパス)
let gradients: Gradients = loss.backward();

// 4. 勾配を使ってモデルのパラメータを更新
optimizer.update(&mut model, &gradients);

backward()で得られた勾配コンテナをoptimizer.update()に渡して、モデルのパラメータを更新するのがポイントです。

複数GPUでの並列学習

Burnの勾配コンテナ方式は、複数のGPUを使った分散学習を安全かつシンプルに実現できます:

use std::thread;

// 各GPUでモデルのコピーを使って計算
let model_gpu1 = model.clone().to_device(Device::Gpu(0));
let model_gpu2 = model.clone().to_device(Device::Gpu(1));
// ... 入力データも各GPUへ

let handle1 = thread::spawn(move || {
    let output = model_gpu1.forward(input_gpu1);
    let loss = calculate_loss(output, target_gpu1);
    loss.backward() // GPU1での勾配コンテナを返す
});

let handle2 = thread::spawn(move || {
    let output = model_gpu2.forward(input_gpu2);
    let loss = calculate_loss(output, target_gpu2);
    loss.backward() // GPU2での勾配コンテナを返す
});

// 各スレッドから返ってきた勾配コンテナを集約
let grads1 = handle1.join().unwrap();
let grads2 = handle2.join().unwrap();
let combined_grads = combine_gradients(grads1, grads2);

// 集約した勾配で元のモデルを更新
optimizer.update(&mut model, &combined_grads);

Burnの公式リポジトリによると、「Burnはスレッドセーフ性をRustの所有権システムを活用して強調しています。各モジュールは自身の重みの所有者であるため、モジュールを別スレッドに送って勾配を計算し、その勾配をメインスレッドに送って集約することが可能です」とのことです。これはPyTorchの方法とは大きく異なり、Rustの所有権のおかげでデータ競合を心配せずに並列処理を行うことができます。

カスタム関数の微分

自分で定義した関数も、Burnのテンソル操作を使って書いていれば自動的に微分可能になります:

// カスタムGELU活性化関数
fn gelu_custom<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
    let x_clone = x.clone();
    let intermediate = (x / SQRT_2).erf() + 1.0;
    let result = x_clone * intermediate / 2.0;
    result
}

この関数内で使われている足し算、掛け算、割り算、誤差関数(erf)などの操作はすべて微分可能なので、関数全体も自動的に微分可能になります。特別な設定をしなくても、自作の関数の勾配も計算できるのです。

まとめ

Burnの自動微分機能について学ぶことでAIの学習プロセスの理解が深まりました。

重要なポイント:

  1. 自動微分はAI学習の核心技術: AIモデルのパラメータをどう調整すべきかの「ヒント」を効率的に計算してくれる
  2. 計算グラフの記録と逆再生: 計算手順を記録し、逆向きにたどることで勾配を効率的に計算
  3. BurnのRust的アプローチ: 型システムと所有権を活かした安全で効率的な設計
  4. 実用的な応用: 分散学習やカスタム関数の自動微分など、幅広い応用が可能

補足

Burnを使った様々なディープラーニングモデル(LLaMA、CNN、BERT、YOLOなど)が実装されており、GitHub上のmodelsリポジトリで確認できます。Burnが実際のAIモデル構築にどう活かされているかを見る良い例となっています。

参考リンク

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?