機械学習用の WASI API である wasi-nn を使って、ONNX 形式のニューラルネットワークを処理する WebAssembly コンポーネントを作ってみました。
はじめに
事前準備をいくつか行います。
wasmtime ビルド
今回は wasmtime で実行します。
https://github.com/bytecodealliance/wasmtime/releases で配布されているバイナリでは必要な機能が有効化されておらず実行できなかったので、--features
オプションを使って手動ビルドしました。
$ git clone https://github.com/bytecodealliance/wasmtime
$ cd wasmtime
$ cargo build --release --features component-model,wasi-nn,wasmtime-wasi-nn/onnx
wit-deps CLI インストール
wit-deps を使って WIT の依存関係を解決する場合、次のようにしてインストールします。
$ cargo install wit-deps-cli
これで、wit/deps.toml
ファイルへ記載した WIT の依存関係を wit-deps
コマンドで処理できます。
ONNX モデル作成
ここでは、「BurnでONNXのモデルを実行」と同じ方法で ONNX モデルを作成して使用します。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.onnx
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(2, 3)
self.fc2 = nn.Linear(3, 1)
def forward(self, x):
x = F.relu(self.fc1(x))
return self.fc2(x)
net = Net()
net.fc1.weight = nn.Parameter(torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]))
net.fc1.bias = nn.Parameter(torch.tensor([0.7, 0.8, 0.9]))
net.fc2.weight = nn.Parameter(torch.tensor([[0.4, 0.3, 0.2]]))
net.fc2.bias = nn.Parameter(torch.tensor([0.1]))
net.eval()
dummy_input = torch.zeros((1, 2))
torch.onnx.export(net, dummy_input, 'model/sample.onnx', input_names=['input'], output_names=['output'])
wasi-nn では入力名と出力名を指定する必要があるため、input
と output
になるよう変更を加えました。
コマンドライン実行用コンポーネント作成
まずは、コマンドライン実行する WebAssembly コンポーネントを作成します。
wasmtime-wasi-nn
クレートを使って実装する方法もあるようですが、ここでは使わずに wit-bindgen で実装してみました。
WIT 定義
このコンポーネントの WIT ファイルはこのようにしました。
package sample:eval;
world eval {
import wasi:nn/graph@0.2.0-rc-2024-08-19;
import wasi:nn/tensor@0.2.0-rc-2024-08-19;
export wasi:cli/run@0.2.3;
}
wasi-nn や cli のバージョンは、現時点で wasmtime がサポートしているバージョンに合わせています。1
WIT 依存関係
wit/deps.toml
へ依存する WIT のアーカイブファイル2を指定します。
cli = "https://github.com/WebAssembly/wasi-cli/archive/refs/tags/v0.2.3.tar.gz"
nn = "https://github.com/WebAssembly/wasi-nn/archive/refs/tags/0.2.0-rc-2024-08-19.tar.gz"
wit-deps update
を実行すれば、依存する WIT ファイルが wit/deps ディレクトリへ配置されます。
$ wit-deps update
実装
wasi-nn の使い方は次のようになります。
- load でモデルをロード
- init_execution_context で実行コンテキストを取得
- set_input で入力データを設定
- compute で実行
- get_output で結果を取得
ONNX形式のモデルをCPUで実行する場合の実装例はこうなります。
なお、バイトオーダーはリトルエンディアンにします。
use exports::wasi::cli::run::Guest;
use wasi::nn::graph::{ExecutionTarget, GraphEncoding, load};
use wasi::nn::tensor::{Tensor, TensorType};
use std::fs;
wit_bindgen::generate!({
world: "eval",
generate_all
});
struct Host;
impl Guest for Host {
fn run() -> Result<(), ()> {
let res = evaluate().map_err(|e| {
println!("ERROR: {}", e);
()
})?;
println!("output={}", res);
Ok(())
}
}
// wasi-nn 処理
fn evaluate() -> Result<f32, Box<dyn std::error::Error>> {
let model = fs::read("model/sample.onnx")?;
let err_handler = |e: wasi::nn::errors::Error| format!("{:?}", e);
let graph = load(&[model], GraphEncoding::Onnx, ExecutionTarget::Cpu).map_err(err_handler)?;
let ctx = graph.init_execution_context().map_err(err_handler)?;
// 入力の設定
let data = [1.1f32.to_le_bytes(), 2.2f32.to_le_bytes()].concat();
let tensor = Tensor::new(&[1, 2], TensorType::Fp32, &data);
ctx.set_input("input", tensor).map_err(err_handler)?;
ctx.compute().map_err(err_handler)?;
// 出力の取得
let output = ctx.get_output("output").map_err(err_handler)?;
let output_data: [u8; 4] = output.data().try_into().map_err(|e| format!("{:?}", e))?;
Ok(f32::from_le_bytes(output_data))
}
export!(Host);
[dependencies]
wit-bindgen = "0.41.0"
[lib]
crate-type = ["cdylib"]
ビルド
wasm32-wasip2
をターゲット指定してビルドします。
$ cargo build --release --target wasm32-wasip2
実行
wasmtime で実行します。
wasi-nn を適用するには --wasi nn
(もしくは -S nn
)のオプション指定が必要です。
--dir
オプションを使うと、ホストOS側のディレクトリとWebAssembly側のディレクトリをマッピングできます。
例えば、親ディレクトリの model ディレクトリを(WebAssembly側で)カレントの model ディレクトリへマッピングする場合はこのようになります。
$ wasmtime --wasi nn --dir ../model::./model target/wasm32-wasip2/release/sample1.wasm
output=1.7570001
ちなみに、カレントディレクトリへ model ディレクトリを配置して実行する場合は wasmtime --wasi nn --dir . target/・・・
となります。
Webサーバー実行用コンポーネント作成
次に、HTTPハンドラとしてWebサーバー実行する WebAssembly コンポーネントを作成します。
world.wit や deps.toml は wasi-cli の箇所を wasi-http へ変更するだけです。
package sample:eval;
world eval {
import wasi:nn/graph@0.2.0-rc-2024-08-19;
import wasi:nn/tensor@0.2.0-rc-2024-08-19;
export wasi:http/incoming-handler@0.2.3;
}
http = "https://github.com/WebAssembly/wasi-http/archive/refs/tags/v0.2.3.tar.gz"
nn = "https://github.com/WebAssembly/wasi-nn/archive/refs/tags/0.2.0-rc-2024-08-19.tar.gz"
HTTP ハンドラとして実装すると例えばこのようになります。
リクエストの度に ONNX モデルを毎回ロードするのは非効率なので、ここでは LazyLock
3を使って初回だけロードするようにしました。
use exports::wasi::http::incoming_handler::{Guest, IncomingRequest, ResponseOutparam};
use wasi::http::types::{ErrorCode, Headers, OutgoingResponse};
use wasi::nn::graph::{ExecutionTarget, Graph, GraphEncoding, load};
use wasi::nn::tensor::{Tensor, TensorType};
use std::collections::HashMap;
use std::fs;
use std::sync::LazyLock;
wit_bindgen::generate!({
world: "eval",
generate_all
});
struct Host;
impl Guest for Host {
fn handle(request: IncomingRequest, response_out: ResponseOutparam) -> () {
let res = extract_params(&request)
.and_then(|(a, b)| evaluate(a, b))
.and_then(|x| to_output(&x))
.map_err(|e| ErrorCode::InternalError(Some(e.to_string())));
ResponseOutparam::set(response_out, res);
()
}
}
type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
static GRAPH: LazyLock<Graph> = LazyLock::new(|| {
let model = fs::read("model/sample.onnx").unwrap();
load(&[model], GraphEncoding::Onnx, ExecutionTarget::Cpu).unwrap()
});
fn extract_params(request: &IncomingRequest) -> Result<(f32, f32)> {
let params = to_params(request)?;
let a = get_to_f32(¶ms, "a")?;
let b = get_to_f32(¶ms, "b")?;
Ok((a, b))
}
fn to_params(request: &IncomingRequest) -> Result<HashMap<String, String>> {
let path_query = request.path_with_query().ok_or("failed path")?;
let (_, query) = path_query.split_once('?').ok_or("no querystring")?;
let res = query
.split('&')
.filter_map(|x| x.split_once('='))
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect::<HashMap<String, String>>();
Ok(res)
}
fn get_to_f32(params: &HashMap<String, String>, key: &str) -> Result<f32> {
let v = params.get(key).ok_or(format!("no params: {}", key))?;
v.parse::<f32>().map_err(|e| e.into())
}
fn evaluate(a: f32, b: f32) -> Result<f32> {
let err_handler = |e: wasi::nn::errors::Error| format!("{:?}", e);
let ctx = GRAPH.init_execution_context().map_err(err_handler)?;
let data = [a.to_le_bytes(), b.to_le_bytes()].concat();
let tensor = Tensor::new(&[1, 2], TensorType::Fp32, &data);
ctx.set_input("input", tensor).map_err(err_handler)?;
ctx.compute().map_err(err_handler)?;
let output = ctx.get_output("output").map_err(err_handler)?;
let output_data: [u8; 4] = output.data().try_into().map_err(|e| format!("{:?}", e))?;
let res = f32::from_le_bytes(output_data);
Ok(res)
}
fn to_output(value: &f32) -> Result<OutgoingResponse> {
let res = format!(r#"{{ "result": {} }}"#, value);
let h = Headers::new();
h.append("content-length", res.len().to_string().as_bytes())?;
let r = OutgoingResponse::new(h);
let b = r.body().map_err(|_| "failed outgoing body")?;
let w = b.write().map_err(|_| "failed outgoing write")?;
w.write(res.as_bytes())?;
w.flush()?;
Ok(r)
}
export!(Host);
ビルドと実行
$ cargo build --release --target wasm32-wasip2
Webサーバーとして実行するには wasmtime の serve
コマンドを使います。
--wasi nn
に加えて --wasi cli
オプションも指定する必要がありました。4
$ wasmtime serve --wasi cli --wasi nn --dir ../model::./model target/wasm32-wasip2/release/sample2.wasm
Serving HTTP on http://0.0.0.0:8080/
実行結果はこのようになります。
$ curl "http://127.0.0.1:8080/?a=1.5&b=2.3"
{ "result": 1.8810002 }