0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

wasi-nnでONNX形式ニューラルネットを処理するWebAssemblyコンポーネント作成

Posted at

機械学習用の WASI API である wasi-nn を使って、ONNX 形式のニューラルネットワークを処理する WebAssembly コンポーネントを作ってみました。

はじめに

事前準備をいくつか行います。

wasmtime ビルド

今回は wasmtime で実行します。

https://github.com/bytecodealliance/wasmtime/releases で配布されているバイナリでは必要な機能が有効化されておらず実行できなかったので、--features オプションを使って手動ビルドしました。

ONNXサポート有効化ビルド
$ git clone https://github.com/bytecodealliance/wasmtime
$ cd wasmtime
$ cargo build --release --features component-model,wasi-nn,wasmtime-wasi-nn/onnx

wit-deps CLI インストール

wit-deps を使って WIT の依存関係を解決する場合、次のようにしてインストールします。

wit-deps-cliインストール
$ cargo install wit-deps-cli

これで、wit/deps.toml ファイルへ記載した WIT の依存関係を wit-deps コマンドで処理できます。

ONNX モデル作成

ここでは、「BurnでONNXのモデルを実行」と同じ方法で ONNX モデルを作成して使用します。

create_onnx.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.onnx

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.fc1 = nn.Linear(2, 3)
        self.fc2 = nn.Linear(3, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)

net = Net()

net.fc1.weight = nn.Parameter(torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]))
net.fc1.bias = nn.Parameter(torch.tensor([0.7, 0.8, 0.9]))

net.fc2.weight = nn.Parameter(torch.tensor([[0.4, 0.3, 0.2]]))
net.fc2.bias = nn.Parameter(torch.tensor([0.1]))

net.eval()

dummy_input = torch.zeros((1, 2))

torch.onnx.export(net, dummy_input, 'model/sample.onnx', input_names=['input'], output_names=['output'])

wasi-nn では入力名と出力名を指定する必要があるため、inputoutput になるよう変更を加えました。

コマンドライン実行用コンポーネント作成

まずは、コマンドライン実行する WebAssembly コンポーネントを作成します。

wasmtime-wasi-nn クレートを使って実装する方法もあるようですが、ここでは使わずに wit-bindgen で実装してみました。

WIT 定義

このコンポーネントの WIT ファイルはこのようにしました。

wit/world.wit
package sample:eval;

world eval {
    import wasi:nn/graph@0.2.0-rc-2024-08-19;
    import wasi:nn/tensor@0.2.0-rc-2024-08-19;

    export wasi:cli/run@0.2.3;
}

wasi-nn や cli のバージョンは、現時点で wasmtime がサポートしているバージョンに合わせています。1

WIT 依存関係

wit/deps.toml へ依存する WIT のアーカイブファイル2を指定します。

/wit/deps.toml
cli = "https://github.com/WebAssembly/wasi-cli/archive/refs/tags/v0.2.3.tar.gz"
nn = "https://github.com/WebAssembly/wasi-nn/archive/refs/tags/0.2.0-rc-2024-08-19.tar.gz"

wit-deps update を実行すれば、依存する WIT ファイルが wit/deps ディレクトリへ配置されます。

WITの依存関係解決
$ wit-deps update

実装

wasi-nn の使い方は次のようになります。

  1. load でモデルをロード
  2. init_execution_context で実行コンテキストを取得
  3. set_input で入力データを設定
  4. compute で実行
  5. get_output で結果を取得

ONNX形式のモデルをCPUで実行する場合の実装例はこうなります。
なお、バイトオーダーはリトルエンディアンにします。

src/lib.rs
use exports::wasi::cli::run::Guest;
use wasi::nn::graph::{ExecutionTarget, GraphEncoding, load};
use wasi::nn::tensor::{Tensor, TensorType};

use std::fs;

wit_bindgen::generate!({
    world: "eval",
    generate_all
});

struct Host;

impl Guest for Host {
    fn run() -> Result<(), ()> {
        let res = evaluate().map_err(|e| {
            println!("ERROR: {}", e);
            ()
        })?;

        println!("output={}", res);

        Ok(())
    }
}
// wasi-nn 処理
fn evaluate() -> Result<f32, Box<dyn std::error::Error>> {
    let model = fs::read("model/sample.onnx")?;

    let err_handler = |e: wasi::nn::errors::Error| format!("{:?}", e);

    let graph = load(&[model], GraphEncoding::Onnx, ExecutionTarget::Cpu).map_err(err_handler)?;
    let ctx = graph.init_execution_context().map_err(err_handler)?;

    // 入力の設定
    let data = [1.1f32.to_le_bytes(), 2.2f32.to_le_bytes()].concat();
    let tensor = Tensor::new(&[1, 2], TensorType::Fp32, &data);
    ctx.set_input("input", tensor).map_err(err_handler)?;

    ctx.compute().map_err(err_handler)?;

    // 出力の取得
    let output = ctx.get_output("output").map_err(err_handler)?;
    let output_data: [u8; 4] = output.data().try_into().map_err(|e| format!("{:?}", e))?;

    Ok(f32::from_le_bytes(output_data))
}

export!(Host);
Cargo.toml
[dependencies]
wit-bindgen = "0.41.0"

[lib]
crate-type = ["cdylib"]

ビルド

wasm32-wasip2 をターゲット指定してビルドします。

ビルド
$ cargo build --release --target wasm32-wasip2 

実行

wasmtime で実行します。

wasi-nn を適用するには --wasi nn(もしくは -S nn)のオプション指定が必要です。

--dir オプションを使うと、ホストOS側のディレクトリとWebAssembly側のディレクトリをマッピングできます。
例えば、親ディレクトリの model ディレクトリを(WebAssembly側で)カレントの model ディレクトリへマッピングする場合はこのようになります。

実行結果
$ wasmtime --wasi nn --dir ../model::./model  target/wasm32-wasip2/release/sample1.wasm
output=1.7570001

ちなみに、カレントディレクトリへ model ディレクトリを配置して実行する場合は wasmtime --wasi nn --dir . target/・・・ となります。

Webサーバー実行用コンポーネント作成

次に、HTTPハンドラとしてWebサーバー実行する WebAssembly コンポーネントを作成します。

world.wit や deps.toml は wasi-cli の箇所を wasi-http へ変更するだけです。

wit/world.wit
package sample:eval;

world eval {
    import wasi:nn/graph@0.2.0-rc-2024-08-19;
    import wasi:nn/tensor@0.2.0-rc-2024-08-19;

    export wasi:http/incoming-handler@0.2.3;
}
wit/deps.toml
http = "https://github.com/WebAssembly/wasi-http/archive/refs/tags/v0.2.3.tar.gz"
nn = "https://github.com/WebAssembly/wasi-nn/archive/refs/tags/0.2.0-rc-2024-08-19.tar.gz"

HTTP ハンドラとして実装すると例えばこのようになります。

リクエストの度に ONNX モデルを毎回ロードするのは非効率なので、ここでは LazyLock 3を使って初回だけロードするようにしました。

src/lib.rs
use exports::wasi::http::incoming_handler::{Guest, IncomingRequest, ResponseOutparam};
use wasi::http::types::{ErrorCode, Headers, OutgoingResponse};
use wasi::nn::graph::{ExecutionTarget, Graph, GraphEncoding, load};
use wasi::nn::tensor::{Tensor, TensorType};

use std::collections::HashMap;
use std::fs;
use std::sync::LazyLock;

wit_bindgen::generate!({
    world: "eval",
    generate_all
});

struct Host;

impl Guest for Host {
    fn handle(request: IncomingRequest, response_out: ResponseOutparam) -> () {
        let res = extract_params(&request)
            .and_then(|(a, b)| evaluate(a, b))
            .and_then(|x| to_output(&x))
            .map_err(|e| ErrorCode::InternalError(Some(e.to_string())));

        ResponseOutparam::set(response_out, res);

        ()
    }
}

type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

static GRAPH: LazyLock<Graph> = LazyLock::new(|| {
    let model = fs::read("model/sample.onnx").unwrap();
    load(&[model], GraphEncoding::Onnx, ExecutionTarget::Cpu).unwrap()
});

fn extract_params(request: &IncomingRequest) -> Result<(f32, f32)> {
    let params = to_params(request)?;

    let a = get_to_f32(&params, "a")?;
    let b = get_to_f32(&params, "b")?;

    Ok((a, b))
}

fn to_params(request: &IncomingRequest) -> Result<HashMap<String, String>> {
    let path_query = request.path_with_query().ok_or("failed path")?;
    let (_, query) = path_query.split_once('?').ok_or("no querystring")?;

    let res = query
        .split('&')
        .filter_map(|x| x.split_once('='))
        .map(|(k, v)| (k.to_string(), v.to_string()))
        .collect::<HashMap<String, String>>();

    Ok(res)
}

fn get_to_f32(params: &HashMap<String, String>, key: &str) -> Result<f32> {
    let v = params.get(key).ok_or(format!("no params: {}", key))?;

    v.parse::<f32>().map_err(|e| e.into())
}

fn evaluate(a: f32, b: f32) -> Result<f32> {
    let err_handler = |e: wasi::nn::errors::Error| format!("{:?}", e);

    let ctx = GRAPH.init_execution_context().map_err(err_handler)?;

    let data = [a.to_le_bytes(), b.to_le_bytes()].concat();
    let tensor = Tensor::new(&[1, 2], TensorType::Fp32, &data);
    ctx.set_input("input", tensor).map_err(err_handler)?;

    ctx.compute().map_err(err_handler)?;

    let output = ctx.get_output("output").map_err(err_handler)?;
    let output_data: [u8; 4] = output.data().try_into().map_err(|e| format!("{:?}", e))?;

    let res = f32::from_le_bytes(output_data);

    Ok(res)
}

fn to_output(value: &f32) -> Result<OutgoingResponse> {
    let res = format!(r#"{{ "result": {} }}"#, value);

    let h = Headers::new();
    h.append("content-length", res.len().to_string().as_bytes())?;

    let r = OutgoingResponse::new(h);

    let b = r.body().map_err(|_| "failed outgoing body")?;
    let w = b.write().map_err(|_| "failed outgoing write")?;

    w.write(res.as_bytes())?;
    w.flush()?;

    Ok(r)
}

export!(Host);

ビルドと実行

ビルド
$ cargo build --release --target wasm32-wasip2 

Webサーバーとして実行するには wasmtime の serve コマンドを使います。
--wasi nn に加えて --wasi cli オプションも指定する必要がありました。4

サーバー実行
$ wasmtime serve --wasi cli --wasi nn --dir ../model::./model target/wasm32-wasip2/release/sample2.wasm
Serving HTTP on http://0.0.0.0:8080/

実行結果はこのようになります。

実行例
$ curl "http://127.0.0.1:8080/?a=1.5&b=2.3"
{ "result": 1.8810002 }
  1. 最新バージョンを使用すると wasmtime 実行時にエラーが発生しました

  2. wit ディレクトリを含む tar.gz ファイルの URL を指定。zipファイルだとエラーになった

  3. OnceLock でも代用可能

  4. wasi:cli/environment の import を解決するため wasi-cli を適用する必要があった

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?