LoginSignup
0
0

BurnでONNXのモデルを実行

Posted at

Rust用のディープラーニングフレームワーク BurnONNX(Open Neural Network eXchange) 形式のモデルを実行してみました。

はじめに

以前、tch-rs の際に使ったモデルを ONNX 形式でエクスポートして使います。

PyTorch では torch.onnx.export でエクスポートできますが、第二引数で入力データの形状を指定する必要がありました。

注意点として、入力データの形状を 1次元配列(例. dummy_input = torch.zeros(2))にしたところ、Burn 側の実行時にエラーが発生して上手くいきませんでした。

そこで、下記ではバッチサイズ 1 の 2次元配列 (1 x 2) にしています。

sample.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.onnx

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.fc1 = nn.Linear(2, 3)
        self.fc2 = nn.Linear(3, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)

net = Net()

net.fc1.weight = nn.Parameter(torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]))
net.fc1.bias = nn.Parameter(torch.tensor([0.7, 0.8, 0.9]))

net.fc2.weight = nn.Parameter(torch.tensor([[0.4, 0.3, 0.2]]))
net.fc2.bias = nn.Parameter(torch.tensor([0.1]))

net.eval()

# 入力データの形状を指定するためのもの
dummy_input = torch.zeros((1, 2))
# ONNX フォーマットでエクスポート
torch.onnx.export(net, dummy_input, 'sample.onnx')

エクスポート結果の sample.onnx ファイルを ONNX 可視化ツールの NETRON で確認したところ次のようになりました。

onnx1.png

Burn で ONNX モデル実行

Burn で ONNX モデルを利用するには、基本的にこのような使い方になるようです。

  1. Rust のビルドスクリプトで ONNX モデルをインポート
  2. インポート結果を使って処理を実装

ビルドスクリプト(build-dependencies で設定)の依存関係も含め、Cargo.toml はこのようになりました。

Cargo.toml
[package]
name = "sample"
version = "0.1.0"
edition = "2021"

[dependencies]
burn = "0.12"
burn-ndarray = "0.12"

[build-dependencies]
burn-import = "0.12"

なお、Burn は wgpu にも対応していますが、今回は CPU で処理します。

ONNX モデルインポート

Rust のビルドスクリプト(ビルドの前処理)を使って、sample.onnx をインポートして model ディレクトリへ出力します。

build.rs
use burn_import::onnx::ModelGen;

fn main() {
    ModelGen::new()
        .input("sample.onnx")
        .out_dir("model/")
        .run_from_script();
}

この結果、target/debug/build/sample-f99bae906abcb7d4/out/model ディレクトリへ下記ファイルが生成されました。

  • sample.rs
  • sample.mpk

sample.rs にモデルを Burn 用に変換した型が定義され、sample.mpk に重みやバイアスのパラメータが出力されているようです。

インポートしたモデル利用

ONNX のインポート結果を使って処理を実装します。

モジュール定義

まずは、インポートで生成された sample.rs の内容を取り込んでモジュール定義します。

src/model/mod.rs
pub mod sample {
    include!(concat!(env!("OUT_DIR"), "/model/sample.rs"));
}

処理の実装

上記モジュールの model::sample::Model を使って処理します。

Model::default() を使用する事で .mpk ファイルから重みやバイアスを復元します。代わりに Model::new(&device) を使うとランダムな値になってしまうのでご注意ください。

src/main.rs
use burn::tensor::Tensor;
use burn_ndarray::{NdArray, NdArrayDevice};

mod model;
use model::sample::Model;

type Backend = NdArray<f32>;

fn main() {
    let device = NdArrayDevice::default();
    let model: Model<Backend> = Model::default();

    let input = Tensor::<Backend, 2>::from_floats([[1.0, 2.0]], &device);

    println!("{}", input);

    let output = model.forward(input);

    println!("{}", output);
}

実行結果

実行結果はこのようになりました。

$ cargo run
...
Tensor {
  data:
[[1.0, 2.0]],
  shape:  [1, 2],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}
Tensor {
  data:
[[1.6700001]],
  shape:  [1, 1],
  device:  Cpu,
  backend:  "ndarray",
  kind:  "Float",
  dtype:  "f32",
}

今回の処理結果は 1.6700001 で、以前 の tch-rs による結果は 1.6700 だったので問題は無さそうです。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0