Rustの機械学習フレームワークBurnで、「データローダー」の実装方法について紹介します。入門データセットとして広く使われているMNIST(手書き数字)データセットを扱う方法を解説します。
burnフレームワークとは?
BurnはRustのネイティブ機械学習フレームワークで、モダンなAPIと高いパフォーマンスを兼ね備えています。PyTorchなどのPythonベースのフレームワークに親しんだ開発者でも使いやすい設計がなされていながら、Rustの強力な型システムの恩恵を受けられるのが大きな特徴です。
Burnの特徴:
- 強力な型安全性と完全性チェック
- CPUとGPUの両方でのトレーニングをサポート
- Pythonとの相互運用性
- 豊富な機械学習アルゴリズムとレイヤーの実装
MNISTデータセットとは?
MNIST(Modified National Institute of Standards and Technology)データセットは、機械学習、特に画像認識の入門用データセットとして広く使われています。0から9までの手書き数字の画像28×28ピクセルの画像から構成され、トレーニング用に60,000枚、テスト用に10,000枚のサンプルが含まれています。
機械学習モデルにとっては、このデータセットを効率的に読み込み、処理できるデータローダーが不可欠です。
データローダーのコード解説
それでは、BurnフレームワークでMNISTデータセットを扱うためのデータローダーの実装を見ていきましょう。以下のコードは、MNISTデータをバッチ処理するためのコアとなるコンポーネントです。
use burn::{
data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
prelude::*,
};
#[derive(Clone, Default)]
pub struct MnistBatcher {}
#[derive(Clone, Debug)]
pub struct MnistBatch<B: Backend> {
pub images: Tensor<B, 3>,
pub targets: Tensor<B, 1, Int>,
}
impl<B: Backend> Batcher<B, MnistItem, MnistBatch<B>> for MnistBatcher {
fn batch(&self, items: Vec<MnistItem>, device: &B::Device) -> MnistBatch<B> {
let images = items
.iter()
.map(|item| TensorData::from(item.image).convert::<B::FloatElem>())
.map(|data| Tensor::<B, 2>::from_data(data, device))
.map(|tensor| tensor.reshape([1, 28, 28]))
.map(|tensor| ((tensor / 255) - 0.1307) / 0.3081)
.collect();
let targets = items
.iter()
.map(|item| {
Tensor::<B, 1, Int>::from_data([(item.label as i64).elem::<B::IntElem>()], device)
})
.collect();
let images = Tensor::cat(images, 0);
let targets = Tensor::cat(targets, 0);
MnistBatch { images, targets }
}
}
コードの構成要素と役割
1. インポート
use burn::{
data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
prelude::*,
};
ここでは、必要なモジュールをインポートしています:
-
Batcher
: データをバッチ単位で処理するためのトレイト -
MnistItem
: MNISTデータセットの各アイテム(画像とラベル)を表す構造体 -
prelude::*
: burnフレームワークの標準的な機能をまとめてインポート
2. MnistBatcher構造体
#[derive(Clone, Default)]
pub struct MnistBatcher {}
この構造体は空ですが、Batcher
トレイトを実装することで、MNISTデータをバッチ処理する機能を提供します。Clone
とDefault
トレイトの自動導出により、この構造体のインスタンスをクローンしたりデフォルト値で作成したりできます。
空の構造体でも機能を提供できる理由は、Rustではトレイト実装がデータと振る舞いを分離できるためです。今回の場合、バッチャー自体は状態を持たないので、空の構造体で十分なのです。
3. MnistBatch構造体
#[derive(Clone, Debug)]
pub struct MnistBatch<B: Backend> {
pub images: Tensor<B, 3>,
pub targets: Tensor<B, 1, Int>,
}
この構造体は、バッチ処理された後のMNISTデータを保持します:
- ジェネリック型
B
は、burnフレームワークのバックエンド(CPU、CUDA、WebGPUなど)を表します -
images
: 3次元テンソル([バッチサイズ, 高さ, 幅])で、画像データを保持 -
targets
: 1次元の整数テンソルで、各画像のラベル(0-9の数字)を保持
テンソルというのは、機械学習でよく使われる多次元配列のことで、ここでは画像データや数値ラベルを効率的に処理するために使われています。
4. Batcherトレイトの実装
impl<B: Backend> Batcher<B, MnistItem, MnistBatch<B>> for MnistBatcher {
fn batch(&self, items: Vec<MnistItem>, device: &B::Device) -> MnistBatch<B> {
// 実装部分(後述)
}
}
Batcher
トレイトを実装することで、MnistBatcher
に「バッチ作成能力」を与えています。ジェネリクスを使って、さまざまなバックエンド(B)でこの機能を使えるようにしています。
-
Batcher<B, MnistItem, MnistBatch<B>>
: 入力としてMnistItem
のコレクションを受け取り、出力としてMnistBatch<B>
を生成するバッチャーを定義 -
batch
メソッド: 個々のMnistItem
をまとめて、1つのバッチにする処理を実装
データ処理の詳細解説
画像データの処理
let images = items
.iter()
.map(|item| TensorData::from(item.image).convert::<B::FloatElem>())
.map(|data| Tensor::<B, 2>::from_data(data, device))
.map(|tensor| tensor.reshape([1, 28, 28]))
.map(|tensor| ((tensor / 255) - 0.1307) / 0.3081)
.collect();
画像データの処理は、関数型プログラミングのパイプラインとして実装されています。各ステップを順に解説します:
-
データ型の変換:
.map(|item| TensorData::from(item.image).convert::<B::FloatElem>())
MNISTデータの画像(通常は0-255の整数値)を、バックエンド固有の浮動小数点型に変換します。これにより、後続の数値計算が正確に行えるようになります。
-
テンソルの作成:
.map(|data| Tensor::<B, 2>::from_data(data, device))
変換されたデータを、指定されたデバイス(CPUやGPU)上に2次元テンソルとして配置します。
-
形状の変更:
.map(|tensor| tensor.reshape([1, 28, 28]))
テンソルの形状を[1, 28, 28]に変更します。ここでの「1」はチャネル数を表し、モノクロ画像であるMNISTの場合は1つのチャネルとなります。28×28はMNIST画像のピクセルサイズです。
-
正規化:
.map(|tensor| ((tensor / 255) - 0.1307) / 0.3081)
画像データを正規化します:
- まず255で割って0-1の範囲にスケーリング
- 次にMNISTデータセットの平均値0.1307を引く
- 最後にMNISTデータセットの標準偏差0.3081で割る
この正規化により、機械学習モデルの学習が安定し、収束が早くなります。
-
収集:
.collect()
処理された画像テンソルをベクタ型(Vec)に集めます。
ラベルデータの処理
let targets = items
.iter()
.map(|item| {
Tensor::<B, 1, Int>::from_data([(item.label as i64).elem::<B::IntElem>()], device)
})
.collect();
ラベルデータの処理も似たようなパイプラインで行われます:
- 各MNISTアイテムからラベル(0-9の数字)を取得
- ラベルをi64型に変換した後、バックエンド固有の整数型に変換
- 1次元のテンソルとしてデバイス上に配置
- 処理されたラベルテンソルをベクタ型に集める
テンソルの連結とバッチの作成
let images = Tensor::cat(images, 0);
let targets = Tensor::cat(targets, 0);
MnistBatch { images, targets }
最後に、個々の画像とラベルのテンソルを連結し、1つのバッチを作成します:
-
Tensor::cat(images, 0)
: 画像テンソルを0次元(バッチ次元)に沿って連結 -
Tensor::cat(targets, 0)
: ラベルテンソルを同様に連結 - 連結されたテンソルから
MnistBatch
構造体のインスタンスを作成して返す
この処理により、複数のMNISTアイテムが1つのバッチにまとめられ、効率的にモデルに供給できるようになります。
データローダーの使い方
上記で実装したデータローダーは、次のように使用できます:
// MNISTデータセットの読み込み
let dataset = MnistDataset::new(
"path/to/mnist/train-images.idx3-ubyte",
"path/to/mnist/train-labels.idx1-ubyte",
);
// データローダーの作成(バッチサイズ=32)
let batcher = MnistBatcher::default();
let dataloader = DataLoaderBuilder::new(batcher)
.batch_size(32)
.shuffle(true)
.num_workers(2)
.build(dataset);
// データローダーを使ったトレーニングループ
for batch in dataloader.iter() {
let output = model.forward(batch.images);
let loss = cross_entropy_loss(output, batch.targets);
// 勾配計算と最適化ステップ...
}
データローダーの重要性
なぜデータローダーが必要なのでしょうか?それには下記の理由があります:
-
効率性: 大規模なデータセットを小さなバッチに分割して処理することで、メモリ使用量を抑えつつ効率的に学習できます。
-
前処理: データローダーは、正規化やリサイズなどの前処理を一元管理できます。これにより、一貫した前処理が保証され、コードの重複も避けられます。
-
並列処理: 複数のワーカーを使って並列にデータを読み込むことで、I/Oがボトルネックになるのを防ぎます。
-
バックエンド抽象化: burnのデータローダーは、異なるバックエンド(CPU/GPU)に対して同じインターフェイスを提供します。これにより、コードを大幅に変更することなく、異なるハードウェアでモデルをトレーニングできます。
発展的なトピック
カスタムデータ拡張(Data Augmentation)
実際のプロジェクトでは、データ拡張を使って訓練データのバリエーションを増やすことがよくあります。burnでもデータローダーを拡張して、回転やフリップなどの拡張を行うことができます:
// 画像の前処理に拡張を追加する例
.map(|tensor| {
if rand::random::<f32>() > 0.5 {
// 50%の確率で画像を水平方向に反転
tensor.flip(2)
} else {
tensor
}
})
まとめ
つぎはトレーニングについて書いてみたいと思います