1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Rust製MLフレームワーク「Burn」

Posted at

機械学習フレームワークといえば、Python製のPyTorchやTensorFlow、JAXが有名ですが、実はRust言語でも高性能なフレームワークの開発が活発に進んでいます。中でも「Burn」は、柔軟性、効率性、そして移植性の高さを追求した、次世代のディープラーニングフレームワークとして注目されている存在です。

この記事では、最近リリースされた「Burn v0.17.0」の新機能や特徴を解説し、Rustで機械学習を始めてみたい方に向けて、Burnの基本的な魅力をご紹介します。

Burnってどんなフレームワーク?

BurnはRustで書かれたディープラーニングフレームワークで、こんな特徴を持っています:

  • マルチバックエンド対応: GPU (NVIDIA, AMD, Apple Silicon) やCPUなど、様々なハードウェアに対応。
  • 高性能: 計算処理が速く、自動微分もサポート。
  • 最適化: 自動カーネル融合などで処理を高速化。
  • WebAssembly対応: ブラウザ上でも実行可能。
  • 組み込み向け: no_std環境でも動作し、組み込みデバイスにも載せられます。

多くのフレームワークがPythonインターフェースとC/C++バックエンドを組み合わせる「2言語アプローチ」を取るのに対し、BurnはRustの強力な抽象化能力を活かし、ハイレベルなAPIと高性能な実行をRustだけで実現しています。

Burn v0.17.0 の主な新機能

v0.17.0では、さらに使いやすく、パワフルになりました。主な新機能を見てみましょう。

  1. Metalバックエンドの追加
    WGPUパススルーを利用した新しいMetalバックエンドが登場! これにより、AppleのMacやiOSデバイス上でGPUを使った高速計算が可能になりました。

    // Cargo.toml に追加
    burn = { version = "0.17.0", features = ["metal"] }
    
    // コード例
    use burn::prelude::*;
    use burn::backend::wgpu::{Metal, WgpuDevice};
    
    let device = WgpuDevice::default(); // デフォルトのMetalデバイスを選択
    let tensor = Tensor::<Metal, 2>::zeros([2, 4], &device);
    
  2. CubeCLの強化
    内部エンジンのCubeCLが進化し、Cuda、Metal、Rocm、Vulkan、WebGpuといった主要なバックエンドを幅広くサポートするようになりました。

  3. テンソル操作の融合サポートが大幅強化
    要素ごとの演算や縮約、行列乗算などを賢く一つにまとめる「テンソル操作の融合」機能がパワーアップ。より効率的な計算が期待できます。

  4. コンパイル&オートチューンキャッシュ
    一度コンパイルしたバイナリや、最適化されたカーネル設定をキャッシュしておくことで、同じ処理を繰り返す際の実行速度が向上しました。

  5. データ並列トレーニングの改善
    複数のGPUを使ったデータ並列トレーニングがより簡単に。各GPU(ワーカー)へのバッチ割り当てが自動化され、効率よく学習を進められます。

  6. 新しいテンソルスライスAPI
    テンソルの一部を切り出す(スライスする)ためのAPIが、より直感的でRustらしい構文になりました。

    // 以前の書き方
    let slice = tensor.slice([(0, -1), (0, -2)]);
    
    // 新しい書き方 (RustのRange構文)
    let slice = tensor.slice([0..-1, 0..-2]); // 最後の要素を除く範囲、最後の2要素を除く範囲
    
    // より複雑な指定には s![] マクロも便利
    let t = 5;
    let slice = tensor.slice(s![.., t..t + 1, ..]); // 全範囲, t番目のみ, 全範囲
    
  7. 量子化行列乗算の初期実装
    モデルの軽量化に繋がる、量子化された行列乗算の機能が試験的に導入されました。

Burnのパフォーマンスを支える技術

Burnの速さの秘密は、いくつかのスマートな技術にあります。

  • 自動カーネル融合
    Burnは、複数のテンソル操作を自動的に一つの最適化された低レベルカーネル(GPUなどで実行されるプログラム)にまとめ上げます。例えば、以下のようなGELU活性化関数を書くと…

    fn gelu_custom<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
        let x = x.clone() * ((x / SQRT_2).erf() + 1);
        x / 2
    }
    

    実行時には、60行ほどのWGSLシェーダーコードが自動生成され、手書きのGPU実装に匹敵するパフォーマンスを発揮するとのことです。

  • 非同期実行
    一部のバックエンドでは非同期実行を採用。これにより、フレームワーク自体の処理(オーバーヘッド)がモデル計算の邪魔をしないように工夫されています。

  • スレッドセーフな設計
    Rustの「所有権」システムを活かし、各モジュールが自身の重みを管理。これにより、勾配計算を別スレッドで行い、結果をメインスレッドで集約する、といった並列処理も容易に行えます。

  • インテリジェントなメモリ管理
    テンソルのメモリ割り当て・解放を効率化。メモリプールを使ったり、Rustの所有権を頼りにテンソルを安全に「その場で変更」できるタイミングを見極めたりすることで、メモリ使用量を削減しています。

  • 自動カーネル選択
    行列乗算のような操作では、行列のサイズやハードウェアによって最適な実行方法(カーネルパラメータ)が変わります。Burnは自動でベンチマークを行い、現在の状況にベストな設定を選んでくれます。

Burnで使えるバックエンド

Burnは様々な環境で動くように設計されており、以下のようなバックエンドを選んだり組み合わせたりできます。

バックエンド 対応デバイス種類 区分
Cuda NVIDIA GPU ファーストパーティ
ROCm AMD GPU ファーストパーティ
Metal Apple GPU ファーストパーティ
Vulkan Linux & Windows の多くのGPU ファーストパーティ
Wgpu 多くのGPU ファーストパーティ
NdArray 多くのCPU サードパーティ
LibTorch 多くのGPU & CPU サードパーティ
Candle Nvidia, Apple GPU & CPU サードパーティ

これらに加えて、以下のような機能を組み合わせることも可能です。

  • Autodiff: 好きなバックエンドに自動微分機能を追加
  • Fusion: 一部のバックエンドにカーネル融合機能を追加
  • Router: 複数のバックエンドを一つにまとめる(ベータ版)
  • Remote: リモート実行用バックエンド(ベータ版)

例えば、自動微分とカーネル融合をWgpuバックエンドで使いたい場合は、このように型定義します。

use burn::backend::{Autodiff, Fusion, Wgpu};

type MyBackend = Autodiff<Fusion<Wgpu>>;

Burnを使ってみる

簡単な例として、ニューラルネットワークの基本的な部品であるPosition-wise Feed-Forward層を定義してみます。

use burn::{
    nn::{self, Gelu, Dropout, Linear}, // 各モジュールをインポート
    module::Module,
    tensor::backend::Backend,
    tensor::Tensor,
};

#[derive(Module, Debug)] // Moduleトレイトを自動導出
pub struct PositionWiseFeedForward<B: Backend> { // バックエンドBをジェネリック型パラメータに
    linear_inner: Linear<B>,
    linear_outer: Linear<B>,
    dropout: Dropout,
    gelu: Gelu,
}

impl<B: Backend> PositionWiseFeedForward<B> {
    // forwardメソッドを実装
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        // バックエンドBに依存しない形で処理を記述
        let x = self.linear_inner.forward(input);
        let x = self.gelu.forward(x);
        let x = self.dropout.forward(x);
        self.linear_outer.forward(x)
    }
}

このように、Burnでは型パラメータとジェネリクスをうまく使うことで、具体的なハードウェア(バックエンド)を意識せずにモデルを記述できます。

まとめ

BurnはRustで書かれた機械学習フレームワークで、しかも新しいアップデートでMetalにも対応した要注目のプロジェクトです!

参考リンク

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?