12
0

概要

この記事では、LightGBMを用いた推論処理をWASMにしてブラウザ上で実行させる方法を紹介します。具体的には、軽量なビルドアセットが作成可能なTinyGoと全てGo言語で実装されているleavesを使い、軽量でランタイムやモデルファイルが内包されているWASMを作ります。

leaves

leavesは純粋なGo言語だけでGBRTモデルの予測コードを実装しているライブラリで、対応モデルはLightGBM、XGBoost、scikit-learnのtree modelです。今回はLightGBMの学習済みモデルを推論で利用させるために使用します。LightGBMのモデル読み込みは、txt形式とjson形式の2種類が対応しており、自分の環境ではLightGBMのモデルファイルサイズが小さいtxtの方を利用しました。また、標準でカテゴリ特徴量がサポートされていたり、バッチ推論もできるのが特徴量です。推論速度はREADMEによるとC言語実装と比較してもそれなりに高速みたいです。また、External (Unofficial) RepositoriesとしてLightGBM公式のREADMEに紹介されています。
https://github.com/dmitryikh/leaves
https://github.com/microsoft/LightGBM/tree/master

実装

全実装コードは下記にあります。
https://github.com/tsuno0829/tinygo_leaves

TinyGoの環境構築は公式のDockerイメージを利用しました。また公式のサンプルコードを大いに参考にさせて頂きました。
https://tinygo.org/getting-started/install/using-docker/
https://github.com/tinygo-org/tinygo/tree/release/src/examples/wasm

まずLightGBMのモデルファイルですが、今回は2値分類のモデルを作りました。
model.txtという名前で保存し、ファイルサイズは258KBでした。

main.py
from sklearn import datasets
from sklearn.model_selection import train_test_split
import lightgbm as lgb


def main():
    params = {
        'objective': 'binary',
        'metric': 'binary_logloss',
        'learning_rate': 0.1,
        'n_estimators': 100,
        'verbose': -1,
    }

    iris = datasets.load_breast_cancer()
    X, y = iris.data, iris.target
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

    lgb_train = lgb.Dataset(X_train, y_train)
    model = lgb.train(params, lgb_train, valid_sets=[lgb_train])

    y_pred = model.predict(X_test)
    model.save_model('./go/model.txt')
    print(X_test[:1])
    print(y_test[:1])
    print(y_pred[:1])


if __name__ == '__main__':
    main()

出力結果は下記です。ログの最後には入力特徴量とその予測結果を出力しています。WASMにしたときに、同じ入出力が得られるかを確認します。

❯ docker compose run --rm dev-py
/usr/local/lib/python3.11/site-packages/lightgbm/engine.py:177: UserWarning: Found `n_estimators` in params. Will use it instead of argument
  _log_warning(f"Found `{alias}` in params. Will use it instead of argument")
[1]     training's binary_logloss: 0.58296
[2]     training's binary_logloss: 0.516194
︙
[100]   training's binary_logloss: 0.000462198
[[1.247e+01 1.860e+01 8.109e+01 4.819e+02 9.965e-02 1.058e-01 8.005e-02
  3.821e-02 1.925e-01 6.373e-02 3.961e-01 1.044e+00 2.497e+00 3.029e+01
  6.953e-03 1.911e-02 2.701e-02 1.037e-02 1.782e-02 3.586e-03 1.497e+01
  2.464e+01 9.605e+01 6.779e+02 1.426e-01 2.378e-01 2.671e-01 1.015e-01
  3.014e-01 8.750e-02]]
[1]
[0.99957539]

TinyGo×leavesのコードが下記です。
LightGBMのtxtファイルはgo:embedでパラメータをWASMに埋め込む形にしています。また、WASM起動時にmodelInitでleavesのモデルインスタンスを作成しています。inferenceは外部から呼び出される関数で今回はJSで呼び出されます。inferenceは1次元配列を受け取り、データ数と特徴量数の2次元配列にreshapeし、バッチ推論をして、予測値の配列を返す関数です。

wasm.go
package main

import (
	"bufio"
	"bytes"
	"embed"
	"syscall/js"
	"github.com/dmitryikh/leaves"
)

//go:embed model.txt
var fs embed.FS

var model *leaves.Ensemble

func modelInit() (error) {
	useTransformation := true
	data, err := fs.ReadFile("model.txt")
	if err != nil {
		return err
	}

	reader := bytes.NewReader(data)
	bufferedReader := bufio.NewReader(reader)
	model, err = leaves.LGEnsembleFromReader(bufferedReader, useTransformation)
	if err != nil {
		return err
	}

	return nil
}

func inference(this js.Value, args []js.Value) interface{} {
	// JSから渡される配列を取得
	inputJSArray := args[0]
	rows := args[1].Int()
	cols := args[2].Int() 

	// 配列の長さを取得
	length := inputJSArray.Length()

	// Goのfloat64スライスに変換
	input := make([]float64, length)
	for i := 0; i < length; i++ {
		input[i] = float64(inputJSArray.Index(i).Float())
	}

	// 推論処理
	predictions, err := inference_impl(input, rows, cols)
	if err != nil {
		panic(err)
	}

	// 予測結果をJavaScriptの配列に変換
	result := make([]interface{}, rows)
	for i, v := range predictions {
		result[i] = v
	}

	return js.ValueOf(result)
}

func inference_impl(v []float64, rows int, cols int) ([]float64, error) {
	predictions := make([]float64, rows)
	err := model.PredictDense(v, rows, cols, predictions, 0, 1)
	if err != nil {
		return predictions, err
	}
	return predictions, nil
}

func main() {
	err := modelInit()
	if err != nil {
		panic(err)
	}

	wait := make(chan struct{}, 0)
	js.Global().Set("inference", js.FuncOf(inference))
	<-wait
}

そしてこのコードをtinygo build -o ./html/wasm.wasm -target wasm -no-debug ./go/wasm.goすることで、WASMを作成することができます。今回作成されたWASMのサイズは497KBでした。モデルのファイルサイズが258KBだったのでランタイムなどのサイズが239KBとなっています。なかなか軽量です。

最後に作成したWASMの動作確認をブラウザでしてみます。htmlとjsのコードは下記です。jsのコードはWASMを読み込み、1秒後にリクエストを投げて、そのレスポンス結果を出力するコードです。(htmlにあるwasm_exec.jsはjsとwasmのグルーコードで、今回はTinyGoが提供しているjsファイルをそのままコピーして使用しています)

index.html
<!DOCTYPE html>

<html>

<head>
    <script src="wasm_exec.js" defer></script>
    <script src="wasm.js" defer></script>
</head>

</html>
wasm.js
const WASM_URL = 'wasm.wasm';

var wasm;
let go;
let wasmModule = null;

// WebAssembly.Moduleを格納する処理
function loadWasmModule() {
    fetch(WASM_URL).then(response =>
        response.arrayBuffer()
    ).then(bytes =>
        wasmModule = new WebAssembly.Module(bytes)
    );
};

function model_inference(input, rows, cols) {
    go = new Go();
    wasm = (new WebAssembly.Instance(wasmModule, go.importObject));
    go.run(wasm);

    const startTime = performance.now();
    let pred = window.inference(input, rows, cols);
    const endTime = performance.now();
    console.log((endTime - startTime) + 'ms'); // 何ミリ秒かかったかを表示

    return pred
}

setTimeout(() => {
    const input = [1.247e+01, 1.860e+01, 8.109e+01, 4.819e+02, 9.965e-02, 1.058e-01, 8.005e-02, 
                   3.821e-02, 1.925e-01, 6.373e-02, 3.961e-01, 1.044e+00, 2.497e+00, 3.029e+01,
                   6.953e-03, 1.911e-02, 2.701e-02, 1.037e-02, 1.782e-02, 3.586e-03, 1.497e+01,
                   2.464e+01, 9.605e+01, 6.779e+02, 1.426e-01, 2.378e-01, 2.671e-01, 1.015e-01,
                   3.014e-01, 8.750e-02];  // 1次元配列で定義
    const rows = 1;  // データ数
    const cols = 30;  // 特徴量数
    // 期待出力結果: [0.99957539]

    console.log('Input data: ', input);
    let output = model_inference(input, rows, cols);
    console.log('Prediction result: ', output)

}, 1000);

window.onload = loadWasmModule;

localhostで動作確認した結果、pythonと同じリクエストを投げたときに、同じ出力結果が得られました。レスポンス速度も10ms以下になっており、それなりの速度でした。
image.png
ブラウザのWASMの読み込みも(なぜかサイズが少し膨らんでますが)500KB程度でした。
image.png

まとめ

モデルパラメータやリクエスト数などにもよって変わると思いますが、今回のコードでは次のような結果が得られました。WASMからLightGBMのファイルサイズを引いたものがランタイムなどのコードサイズになりますが、こちらも比較的軽量になっていることが分かりました。また、LightGBMのファイルサイズが大きくなればWASMサイズも大きくなる問題もありますが、そこはモデル側で頑張る必要がありそうです。

LightGBMのファイルサイズ WASMのサイズ1 推論速度
258KB 497KB 10ms

所感

ブラウザ上でLightGBMを利用する際に、onnxruntime-webなどの外部ランタイムを利用して動作させる記事は散見されますが、(JSとWASMのグルーコードを除いて)依存がほとんどWASMに閉じた形での実装例は個人的にあまり見かけなかったので今回紹介しました。刺さる人にしか刺さらない記事になってますが、誰かの参考になれば幸いです。ここまで読んで頂きありがとうございました。

後書き

余談ですが、初めはLightGBMをONNX形式に変換してRustでWASMを作ろうとしてました。ONNXの作成はhummingbirdライブラリで、特徴量で数値のみ利用するケースではONNXにできることを確認しました。ただ、カテゴリ特徴量を使うとONNXに変換できないらしく、Issueによると現時点ではサポート外とのことでした (Issueにはonehotエンコーダで数値特徴量に変換を推奨するコメントがありましたが、カーディナリティがそれなりに高い特徴量を扱いたかったので自分のケースでは微妙でした)。こちらについては将来的に対応される可能性はありますが、ONNXにできてもRustのライブラリ側でWASMに正しく変換できるかも怪しく今回はこの方向性は断念しました。ちなみに、数値特徴量だけ利用する場合はtractでWASMまで作れることは確認しましたが、Go言語実装のWASMよりも成果物のサイズが大きく、推論速度が遅かったためボツになりました(推論速度は実行環境によるので例えば並列処理できる環境などでは変わる可能性があります)。tract以外のONNXからWASMを作るRust製のライブラリ(burns, wonnxなど)はいくつか候補にありましたが、LightGBMで使われているONNXのopratorに全て対応できているものがtract以外になく変換できなさそうでした(2023年12月時点)。2
https://github.com/microsoft/hummingbird/issues/245
https://github.com/sonos/tract

  1. 確認していませんが、WASM初期化時にモデルファイルを読み込んでleavesのモデルインスタンスを作るので、オンメモリのサイズは更に大きくなると思います

  2. PyTorchなどで作れるNN系のサポートが早く、LightGBMなどの決定木系は後回しにされてる印象を受けました

12
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
0