従来のPythonを使った方法とは異なり、「安全かつ高速」なRustで機械学習モデルを動かす方法を解説します。
今回は、ONNX RuntimeのRustバインディングである ort クレートを使います。
ONNXとは?
説明するまでもないかもしれませんが、ONNX (Open Neural Network Exchange) は、異なるフレームワーク間で機械学習モデルを相互運用するためのオープンな標準フォーマットです。
特徴としては、
- フレームワークに非依存: PyTorch, TensorFlowなどから変換可能
- ランタイムに非依存: 様々な環境で動作
一度ONNXに変換すれば、推論時のランタイムは自由に選べるという点が最大のメリットでしょう。
ONNX Runtimeとは?
今回扱うONNX Runtimeは、Microsoftが開発する、推論に最適化されたクロスプラットフォームエンジンです。
なぜRustで使うのか?
Rustは以下の点で推論タスクに最適です:
- メモリ安全性: SegmentationFaultなしで安定して動作
- ゼロコスト抽象化: 高レベルAPIでも実行速度が落ちない
- 小さなバイナリ: 依存関係が少なく、エッジデバイスに最適
- 並行処理: Tokioなどで効率的な並列推論が可能
今回は ortクレート を使ってONNX Runtimeを扱います。
実験として、yolo v8のモデルを用いて、画像に対して推論を行います。
セットアップ
1. 依存関係
Cargo.tomlに以下を追加します。今回は画像処理用に image クレート、数値計算用に ndarray クレートも使用しています。
[dependencies]
ort = { version = "2.0.0-rc.9", features = ["download-binaries", "ndarray"] }
image = "0.24"
ndarray = "0.16"
-
"download-binaries"を書くとONNX Runtime の 事前ビルド済みバイナリを自動でダウンロードしてくれます。オフにすると自分でソースからビルドする必要があります。 -
"ndarray"機能を有効にすることで、ndarrayとの互換性をもたせれます。
実装
1. import
use image::{imageops::FilterType, GenericImageView, Pixel};
use ndarray::{Array, Axis};
use ort::session::{builder::GraphOptimizationLevel, Session};
use std::path::Path;
// バウンディングボックスの構造体
#[derive(Debug, Clone, Copy)]
struct BoundingBox {
x1: f32,
y1: f32,
x2: f32,
y2: f32,
}
// IoU計算
fn intersection(box1: &BoundingBox, box2: &BoundingBox) -> f32 {
let x_overlap = (box1.x2.min(box2.x2) - box1.x1.max(box2.x1)).max(0.0);
let y_overlap = (box1.y2.min(box2.y2) - box1.y1.max(box2.y1)).max(0.0);
x_overlap * y_overlap
}
// 2つのボックスの合計面積(重なりを除く)
fn union(box1: &BoundingBox, box2: &BoundingBox) -> f32 {
let area1 = (box1.x2 - box1.x1) * (box1.y2 - box1.y1);
let area2 = (box2.x2 - box2.x1) * (box2.y2 - box2.y1);
area1 + area2 - intersection(box1, box2)
}
2. メイン関数とモデルのロード
fn main() -> ort::Result<()> {
// モデルのロード
let model_path = "yolov8m.onnx";
if !Path::new(model_path).exists() {
return Err(ort::Error::new(format!("モデルパスが見つかりません!!!: {}", model_path)));
}
// セッションの作成
let session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(4)?
.commit_from_file(model_path)?;
Session::builder() で推論セッションを作成します。最適化レベルやスレッド数を指定できます。
3. 画像の前処理
YOLOv8は通常、640x640で正規化されたRGB画像を入力として受け取ります。
そのためここでは、事前に画像の処理を行います。
// 画像の読み込みと前処理
let image_path = "dog.jpg"; // サンプル画像を用意
if !Path::new(image_path).exists() {
return Err(ort::Error::new(format!("画像パスが見つかりません!!!: {}", image_path)));
}
let original_img = image::open(image_path).map_err(|e| ort::Error::new(format!("画像の読み込みに失敗!!!: {}", e)))?;
let (width, height) = original_img.dimensions();
// 640x640にリサイズする
let img = original_img.resize_exact(640, 640, FilterType::CatmullRom);
// NCHWに変換
let mut input = Array::zeros((1, 3, 640, 640));
for pixel in img.pixels() {
let x = pixel.0 as usize;
let y = pixel.1 as usize;
let channels = pixel.2.channels();
let r = channels[0];
let g = channels[1];
let b = channels[2];
// 値は 0.0 - 1.0 に正規化する
input[[0, 0, y, x]] = (r as f32) / 255.0;
input[[0, 1, y, x]] = (g as f32) / 255.0;
input[[0, 2, y, x]] = (b as f32) / 255.0;
}
4. 推論の実行
// 推論実行
let input_value = ort::value::Value::from_array(input)?;
let outputs = session.run(ort::inputs!["images" => input_value]?)?;
// 出力を取得
let output = outputs["output0"].try_extract_tensor::<f32>()?;
let output_view = output.view();
// 出力を変換
let output_view = output_view.permuted_axes(&[0, 2, 1][..]);
let output_view = output_view.index_axis(Axis(0), 0);
println!("推論結果: {:?}", output_view.shape());
5. 後処理と結果の表示
複数検出への対処として、NMS(Non-Maximum Suppression) で重複を除去します。
検出されたバウンディングボックスを描画して保存します。
// バウンディングボックスの抽出
let mut boxes = Vec::new();
let conf_threshold = 0.5; // 信頼度のしきい値
for row in output_view.outer_iter() {
let row_owned = row.to_owned();
let row_slice = row_owned.as_slice().unwrap();
let box_data = &row_slice[0..4]; // ボックスの座標
let scores = &row_slice[4..]; // スコア
let (class_id, &score) = scores.iter().enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.unwrap();
// スコアがしきい値以上のものだけを検出結果として保存
if score > conf_threshold {
let cx = box_data[0];
let cy = box_data[1];
let w = box_data[2];
let h = box_data[3];
// 座標を元画像サイズに復元
let x1 = (cx - w / 2.0) * (width as f32 / 640.0);
let y1 = (cy - h / 2.0) * (height as f32 / 640.0);
let x2 = (cx + w / 2.0) * (width as f32 / 640.0);
let y2 = (cy + h / 2.0) * (height as f32 / 640.0);
boxes.push((
BoundingBox { x1, y1, x2, y2 },
class_id,
score,
));
}
}
// NMS
boxes.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap()); // スコアの降順でソート
let mut result = Vec::new();
while !boxes.is_empty() {
result.push(boxes[0]);
boxes = boxes
.iter()
.filter(|box1| {
let iou = intersection(&boxes[0].0, &box1.0) / union(&boxes[0].0, &box1.0);
iou < 0.7 // IoUが0.7未満のものだけ残す
})
.copied()
.collect();
}
println!("検出数: {} (NMS適用後)", result.len());
// 描画
let mut img_to_draw = original_img.to_rgb8();
for (bbox, class_id, score) in result {
println!("クラスID: {}, スコア: {:.2}, ボックス: ({:.1}, {:.1}, {:.1}, {:.1})",
class_id, score, bbox.x1, bbox.y1, bbox.x2, bbox.y2);
draw_rect(&mut img_to_draw, bbox.x1 as i32, bbox.y1 as i32, bbox.x2 as i32, bbox.y2 as i32);
}
img_to_draw.save("output.jpg").map_err(|e| ort::Error::new(format!("画像の保存に失敗!!!: {}", e)))?;
println!("保存: output.jpg");
Ok(())
}
// 矩形描画関数
fn draw_rect(img: &mut image::RgbImage, x1: i32, y1: i32, x2: i32, y2: i32) {
let color = image::Rgb([255, 0, 0]); // 赤線
let (w, h) = img.dimensions();
let thickness = 10; // 線の太さ
// 上下描画
for t in 0..thickness {
for x in x1..=x2 {
if x >= 0 && x < w as i32 {
// 上の線
let y_top = y1 + t;
if y_top >= 0 && y_top < h as i32 {
img.put_pixel(x as u32, y_top as u32, color);
}
// 下の線
let y_bottom = y2 - t;
if y_bottom >= 0 && y_bottom < h as i32 {
img.put_pixel(x as u32, y_bottom as u32, color);
}
}
}
}
// 左右描画
for t in 0..thickness {
for y in y1..=y2 {
if y >= 0 && y < h as i32 {
// 左の線
let x_left = x1 + t;
if x_left >= 0 && x_left < w as i32 {
img.put_pixel(x_left as u32, y as u32, color);
}
// 右の線
let x_right = x2 - t;
if x_right >= 0 && x_right < w as i32 {
img.put_pixel(x_right as u32, y as u32, color);
}
}
}
}
}
実行方法・結果
- モデルと画像を用意します。
- 実行します。
成功すると以下のように、画像に検出結果が描画され保存されるはずです。
最後に
この記事では、ortクレートを使用してYOLOv8をRustで推論する方法を紹介しました。
Rust + ONNX Runtime (ortクレート) は、以下の3つの強みを持つ推論環境です:
- 高速
- 安全
- 軽量
RustはPythonと比較しても強力な推論環境となれるのではないでしょうか。
この記事ではortのチュートリアル的な内容ですが、今後の記事でより複雑な処理を組み込んだ内容を書きたいと思っています。
