0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【Rust】BurnでMNISTデータセットを扱う

Posted at

Rustの機械学習フレームワークBurnで、「データローダー」の実装方法について紹介します。入門データセットとして広く使われているMNIST(手書き数字)データセットを扱う方法を解説します。

burnフレームワークとは?

BurnはRustのネイティブ機械学習フレームワークで、モダンなAPIと高いパフォーマンスを兼ね備えています。PyTorchなどのPythonベースのフレームワークに親しんだ開発者でも使いやすい設計がなされていながら、Rustの強力な型システムの恩恵を受けられるのが大きな特徴です。

Burnの特徴:

  • 強力な型安全性と完全性チェック
  • CPUとGPUの両方でのトレーニングをサポート
  • Pythonとの相互運用性
  • 豊富な機械学習アルゴリズムとレイヤーの実装

MNISTデータセットとは?

MNIST(Modified National Institute of Standards and Technology)データセットは、機械学習、特に画像認識の入門用データセットとして広く使われています。0から9までの手書き数字の画像28×28ピクセルの画像から構成され、トレーニング用に60,000枚、テスト用に10,000枚のサンプルが含まれています。

MNISTデータセット例

機械学習モデルにとっては、このデータセットを効率的に読み込み、処理できるデータローダーが不可欠です。

データローダーのコード解説

それでは、BurnフレームワークでMNISTデータセットを扱うためのデータローダーの実装を見ていきましょう。以下のコードは、MNISTデータをバッチ処理するためのコアとなるコンポーネントです。

use burn::{
    data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
    prelude::*,
};

#[derive(Clone, Default)]
pub struct MnistBatcher {}

#[derive(Clone, Debug)]
pub struct MnistBatch<B: Backend> {
    pub images: Tensor<B, 3>,
    pub targets: Tensor<B, 1, Int>,
}

impl<B: Backend> Batcher<B, MnistItem, MnistBatch<B>> for MnistBatcher {
    fn batch(&self, items: Vec<MnistItem>, device: &B::Device) -> MnistBatch<B> {
        let images = items
            .iter()
            .map(|item| TensorData::from(item.image).convert::<B::FloatElem>())
            .map(|data| Tensor::<B, 2>::from_data(data, device))
            .map(|tensor| tensor.reshape([1, 28, 28]))
            .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081)
            .collect();

        let targets = items
            .iter()
            .map(|item| {
                Tensor::<B, 1, Int>::from_data([(item.label as i64).elem::<B::IntElem>()], device)
            })
            .collect();

        let images = Tensor::cat(images, 0);
        let targets = Tensor::cat(targets, 0);

        MnistBatch { images, targets }
    }
}

コードの構成要素と役割

1. インポート

use burn::{
    data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
    prelude::*,
};

ここでは、必要なモジュールをインポートしています:

  • Batcher: データをバッチ単位で処理するためのトレイト
  • MnistItem: MNISTデータセットの各アイテム(画像とラベル)を表す構造体
  • prelude::*: burnフレームワークの標準的な機能をまとめてインポート

2. MnistBatcher構造体

#[derive(Clone, Default)]
pub struct MnistBatcher {}

この構造体は空ですが、Batcherトレイトを実装することで、MNISTデータをバッチ処理する機能を提供します。CloneDefaultトレイトの自動導出により、この構造体のインスタンスをクローンしたりデフォルト値で作成したりできます。

空の構造体でも機能を提供できる理由は、Rustではトレイト実装がデータと振る舞いを分離できるためです。今回の場合、バッチャー自体は状態を持たないので、空の構造体で十分なのです。

3. MnistBatch構造体

#[derive(Clone, Debug)]
pub struct MnistBatch<B: Backend> {
    pub images: Tensor<B, 3>,
    pub targets: Tensor<B, 1, Int>,
}

この構造体は、バッチ処理された後のMNISTデータを保持します:

  • ジェネリック型Bは、burnフレームワークのバックエンド(CPU、CUDA、WebGPUなど)を表します
  • images: 3次元テンソル([バッチサイズ, 高さ, 幅])で、画像データを保持
  • targets: 1次元の整数テンソルで、各画像のラベル(0-9の数字)を保持

テンソルというのは、機械学習でよく使われる多次元配列のことで、ここでは画像データや数値ラベルを効率的に処理するために使われています。

4. Batcherトレイトの実装

impl<B: Backend> Batcher<B, MnistItem, MnistBatch<B>> for MnistBatcher {
    fn batch(&self, items: Vec<MnistItem>, device: &B::Device) -> MnistBatch<B> {
        // 実装部分(後述)
    }
}

Batcherトレイトを実装することで、MnistBatcherに「バッチ作成能力」を与えています。ジェネリクスを使って、さまざまなバックエンド(B)でこの機能を使えるようにしています。

  • Batcher<B, MnistItem, MnistBatch<B>>: 入力としてMnistItemのコレクションを受け取り、出力としてMnistBatch<B>を生成するバッチャーを定義
  • batchメソッド: 個々のMnistItemをまとめて、1つのバッチにする処理を実装

データ処理の詳細解説

画像データの処理

let images = items
    .iter()
    .map(|item| TensorData::from(item.image).convert::<B::FloatElem>())
    .map(|data| Tensor::<B, 2>::from_data(data, device))
    .map(|tensor| tensor.reshape([1, 28, 28]))
    .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081)
    .collect();

画像データの処理は、関数型プログラミングのパイプラインとして実装されています。各ステップを順に解説します:

  1. データ型の変換:

    .map(|item| TensorData::from(item.image).convert::<B::FloatElem>())
    

    MNISTデータの画像(通常は0-255の整数値)を、バックエンド固有の浮動小数点型に変換します。これにより、後続の数値計算が正確に行えるようになります。

  2. テンソルの作成:

    .map(|data| Tensor::<B, 2>::from_data(data, device))
    

    変換されたデータを、指定されたデバイス(CPUやGPU)上に2次元テンソルとして配置します。

  3. 形状の変更:

    .map(|tensor| tensor.reshape([1, 28, 28]))
    

    テンソルの形状を[1, 28, 28]に変更します。ここでの「1」はチャネル数を表し、モノクロ画像であるMNISTの場合は1つのチャネルとなります。28×28はMNIST画像のピクセルサイズです。

  4. 正規化:

    .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081)
    

    画像データを正規化します:

    • まず255で割って0-1の範囲にスケーリング
    • 次にMNISTデータセットの平均値0.1307を引く
    • 最後にMNISTデータセットの標準偏差0.3081で割る

    この正規化により、機械学習モデルの学習が安定し、収束が早くなります。

  5. 収集:

    .collect()
    

    処理された画像テンソルをベクタ型(Vec)に集めます。

ラベルデータの処理

let targets = items
    .iter()
    .map(|item| {
        Tensor::<B, 1, Int>::from_data([(item.label as i64).elem::<B::IntElem>()], device)
    })
    .collect();

ラベルデータの処理も似たようなパイプラインで行われます:

  1. 各MNISTアイテムからラベル(0-9の数字)を取得
  2. ラベルをi64型に変換した後、バックエンド固有の整数型に変換
  3. 1次元のテンソルとしてデバイス上に配置
  4. 処理されたラベルテンソルをベクタ型に集める

テンソルの連結とバッチの作成

let images = Tensor::cat(images, 0);
let targets = Tensor::cat(targets, 0);

MnistBatch { images, targets }

最後に、個々の画像とラベルのテンソルを連結し、1つのバッチを作成します:

  1. Tensor::cat(images, 0): 画像テンソルを0次元(バッチ次元)に沿って連結
  2. Tensor::cat(targets, 0): ラベルテンソルを同様に連結
  3. 連結されたテンソルからMnistBatch構造体のインスタンスを作成して返す

この処理により、複数のMNISTアイテムが1つのバッチにまとめられ、効率的にモデルに供給できるようになります。

データローダーの使い方

上記で実装したデータローダーは、次のように使用できます:

// MNISTデータセットの読み込み
let dataset = MnistDataset::new(
    "path/to/mnist/train-images.idx3-ubyte",
    "path/to/mnist/train-labels.idx1-ubyte",
);

// データローダーの作成(バッチサイズ=32)
let batcher = MnistBatcher::default();
let dataloader = DataLoaderBuilder::new(batcher)
    .batch_size(32)
    .shuffle(true)
    .num_workers(2)
    .build(dataset);

// データローダーを使ったトレーニングループ
for batch in dataloader.iter() {
    let output = model.forward(batch.images);
    let loss = cross_entropy_loss(output, batch.targets);
    // 勾配計算と最適化ステップ...
}

データローダーの重要性

なぜデータローダーが必要なのでしょうか?それには下記の理由があります:

  1. 効率性: 大規模なデータセットを小さなバッチに分割して処理することで、メモリ使用量を抑えつつ効率的に学習できます。

  2. 前処理: データローダーは、正規化やリサイズなどの前処理を一元管理できます。これにより、一貫した前処理が保証され、コードの重複も避けられます。

  3. 並列処理: 複数のワーカーを使って並列にデータを読み込むことで、I/Oがボトルネックになるのを防ぎます。

  4. バックエンド抽象化: burnのデータローダーは、異なるバックエンド(CPU/GPU)に対して同じインターフェイスを提供します。これにより、コードを大幅に変更することなく、異なるハードウェアでモデルをトレーニングできます。

発展的なトピック

カスタムデータ拡張(Data Augmentation)

実際のプロジェクトでは、データ拡張を使って訓練データのバリエーションを増やすことがよくあります。burnでもデータローダーを拡張して、回転やフリップなどの拡張を行うことができます:

// 画像の前処理に拡張を追加する例
.map(|tensor| {
    if rand::random::<f32>() > 0.5 {
        // 50%の確率で画像を水平方向に反転
        tensor.flip(2)
    } else {
        tensor
    }
})

まとめ

つぎはトレーニングについて書いてみたいと思います

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?