ort という Rust の ONNX ランタイムラッパを動かしてみた記録です。
vgg16 で画像のクラス分けを行います。
注意: 本記事の内容は ort 1.16 時点のものです。現在 2.0 の開発が進められており、仕様が大きく変わっているようです。
開発環境
rust のプロジェクトを作成し、Cargo.toml に以下を追加します。
[dependencies]
image = "0.25.1"
ndarray = "0.15.6"
ort = "1.16"
https://github.com/onnx/models/tree/main/validated/vision/classification/vgg/model から vgg16-12.onnx
を頂きます。
また、ラベルの一覧を入手します。
以上です。onnxのランタイムはビルド時に環境に合ったものが自動的にダウンロードされます。
実行コード
以下の通りです。
use std::{
fs::File,
io::{BufRead, BufReader},
};
use image::{DynamicImage, GenericImageView};
use ndarray::{s, Array, Array3, CowArray, NewAxis};
use ort::{Environment, GraphOptimizationLevel, SessionBuilder, Value};
fn run(model: &str, labels: &str, image: &str) -> Result<(), Box<dyn std::error::Error>> {
// モデルの準備する。
let environment = Environment::builder().build()?.into_arc();
let session = SessionBuilder::new(&environment)?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(4)?
.with_model_from_file(model)?;
// ラベルを読み込む。
let labels = BufReader::new(File::open(labels)?)
.lines()
.collect::<Result<Vec<_>, _>>()?;
// 入力を準備する。
let input_image = image::open(image)?;
let input_image = resize_and_center_crop(&input_image, 224, 224);
let input = image_to_normalized_ndarray(&input_image);
// 推論する。
let input = CowArray::from(input.slice(s![NewAxis, .., .., ..]).into_dyn());
let inputs = vec![Value::from_array(session.allocator(), &input)?];
let outputs: Vec<Value> = session.run(inputs)?;
let output = outputs[0].try_extract::<f32>()?;
let output = output.view();
// Soft-max
let output = output.map(|x| x.exp());
let output = &output / output.sum();
// 結果を表示する。
let mut output = output
.iter()
.copied()
.zip(labels.iter())
.collect::<Vec<_>>();
output.sort_by(|&(a, _), &(b, _)| a.partial_cmp(&b).unwrap().reverse());
for (prob, tag) in &output[..10] {
println!("{tag}: {prob}%", prob = prob * 100.0);
}
Ok(())
}
/// object-fit: cover
fn resize_and_center_crop(image: &DynamicImage, width: u32, height: u32) -> DynamicImage {
let image = image.resize_to_fill(width, height, image::imageops::FilterType::Triangle);
let image = image.crop_imm(
(image.width() - width) / 2,
(image.height() - height) / 2,
width,
height,
);
image
}
/// image の画像を ndarray に変換する。ついでに色の正規化を行う。
fn image_to_normalized_ndarray(image: &DynamicImage) -> Array3<f32> {
let mut result = Array::zeros((3, 224, 224));
for (x, y, pixel) in image.pixels() {
let x = x as usize;
let y = y as usize;
let [r, g, b, _] = pixel.0;
result[[0, y, x]] = (((r as f64) / 255. - 0.485) / 0.229) as f32;
result[[1, y, x]] = (((g as f64) / 255. - 0.456) / 0.224) as f32;
result[[2, y, x]] = (((b as f64) / 255. - 0.406) / 0.225) as f32;
}
result
}
fn main() {
run(
"vgg16-12.onnx",
"ILSVRC2014-labels.txt",
"1280px-野良猫(相模川沿い).jpg",
)
.unwrap();
}
結果
https://ja.wikipedia.org/wiki/%E3%83%8D%E3%82%B3#/media/%E3%83%95%E3%82%A1%E3%82%A4%E3%83%AB:%E9%87%8E%E8%89%AF%E7%8C%AB%EF%BC%88%E7%9B%B8%E6%A8%A1%E5%B7%9D%E6%B2%BF%E3%81%84%EF%BC%89.jpg
tabby: 27.493656%
tiger cat: 17.814444%
hare: 11.536144%
wood rabbit: 7.2499194%
fox squirrel: 6.9241986%
Egyptian cat: 5.8093266%
tiger: 1.9416237%
hamster: 1.0579933%
Shetland sheepdog: 1.0080262%
collie: 0.97883344%
うまく行ってそうです。
配布方法
cargo build --release
すると target/release
に {プロジェクト名}.exe
, onnxruntime.dll
が作成されます。 (windows の場合)
onnxファイルなどと一緒に、同じディレクトリに入れて配布しましょう。