LoginSignup
3
0

ONNXRuntime-Webを使ってブラウザ拡張でニューラルモデルをローカル推論しよう

Last updated at Posted at 2023-12-15

はじめに

これはKaggle Advent Calendar 2023の16日目の記事です.

みなさん,私生活でニューラルモデル使ってますか?DeepLとかChatGPTとかではなく,皆さんが魂を込めた手元のモデルの話です.
最近のkaggleコンペの題材はかなり一般的なアプリケーションにそのまま転用できそうな話題も増えているように感じます.自分で作ったものに限らなくても,前処理だったりでネットから拾ってきたモデルも便利だったりしますよね.それ,コンペ終わったら即ポイじゃもったいなくないですか?

そこでonnxruntime-webの出番です.onnxruntime-webを使うとブラウザでニューラルモデルのローカル推論ができます.この記事では特にブラウザの拡張機能でモデルを使う方法を見ていきます.ブラウザ拡張の開発はそれ自体すでに万人にお勧めしたいものですが,そこにさらにニューラルモデルが使えるようになればまさに鬼に金棒.あれもうちょっとなんとかなるんちゃうか!?ってストレスをどんどんDeepでポンできます.

前提条件

開発にはNode.jsが必要になります.今回はNode20.10.0で試しています.以下で出てくるコードを実際に動かしたい場合は最新のNode.jsを公式サイトに従ってインストールしておいてください.

とにかく動くやつ

なにはともあれ動作しているものが見たい人は以下のレポジトリのUsage通りにやればとりあえずニューラルモデルが内部で動いてる拡張機能をビルドできます.
動くのを見なくても構わない人はそのまま次の章に行ってください.
https://github.com/yufuin/onnxruntime-web-on-extension/

ビルドしたら,chromeの拡張機能の管理画面を開いて,デベロッパーモードをon➡左上のほうにある"パッケージ化されていない拡張機能を読み込む"でビルドしたdistディレクトリを開くと,新しい拡張機能のパネルが増えるので,そのパネルのservice_workerを見れば動いているのが確認できます.

FireFoxはmanifest.jsonがChromeと違うので動きません.書き換えるかChromeを使ってください.

実装の流れ

拡張機能,というかブラウザ上でニューラルモデルのフォワードを実行するのは,つまるところonnxruntime-webライブラリを使うだけです.実装の大きな流れは以下の通りです.

① PyTorchなりでモデルを設計・学習
② 学習したモデルをONNX or ORT形式でエクスポート
③ ブラウザ拡張でonnxruntime-webライブラリを用いてモデルをロードしローカルで推論を実行

①はこの記事では深入りしないことにします.
②はネットにたくさんやり方書いてあるのですぐできる・・・と思いきやonnxruntime-web特有の制限があって意外に一筋縄ではいきません.
③はやること自体はシンプルなのですが,ここにもハマリポイントがちょっとだけあります.

予備知識: onnxruntime-web特有の制限について

実装に入る前に予備知識です.このあとの実装に関わってきます.

ONNXはモデルファイルフォーマットの一種で,メジャーどころの機械学習ライブラリのモデルならONNX形式へのエクスポート手段が提供されています(pytorchの例).

Open Neural Network Exchange(略称:ONNX)とは、オープンソースで開発されている機械学習や人工知能のモデルを表現する為の代表的なフォーマットである。 実行エンジンとしてONNX Runtimeも開発されている。(Wikipedia

onnxruntime-webはその名の通りwebブラウザ向けのONNXランタイムで,JavaScript+WebAssemblyで作られており,これを使えばあとはモデルをONNX形式で用意することでブラウザ上で動かせるというわけです.

ONNX形式でエクスポートするときの大きな注意点として,変換元のライブラリ(PyTorchとか)のあらゆる演算がサポートされるわけではなく,使えないものもあります.この制限に加えて,onnxruntime-webを使う場合は注意しなければならないことが増えます

前提として,onnxruntime-webには実行時のバックエンドとして,CPU実行するwasm,GPU実行するwebglおよびwebgpuの三つの実行オプションがあります.各バックエンドごとに制限が異なるため,どのバックエンドを利用するかあらかじめ想定して実装を進める必要があります.2023年12月時点の各バックエンドごとの大まかな特徴は以下の通りになります.

バックエンド wasm webgl webgpu
実行 CPU GPU GPU
利用可能な演算 基本制限なし 使えない演算がかなり多い 少し制限有り
float16の利用 不可
主要PC向けブラウザのデフォルト設定での対応状況 全て対応 全て対応 Chrome (Linux以外), Edge, Opera
service_worker上での利用 不可 (offscreen経由などで一応回避可能)
content_script上での利用
巨大モデルの利用 不可 不可 不可

使えない演算があるとは,例えばwebglではRNNやLSTM, ArgMaxなどが使えません.webgpuは演算自体は基本的に全て使用可能ですが,Attentionのmaskなど一部未実装の機能があるようです.詳しくは公式GitHubレポのGPUサポート状況を参照してください.
なお,Linux版Chromeでもブラウザ側で設定すればwebgpuは利用可能で,FireFoxも設定すれば全OSで利用可能です (参考: https://caniuse.com/webgpu ).この設定のハードルが一般利用ではかなり高いのですが.
また,WebAssembly自体のメモリ4GB制限のために現在のところどのバックエンドであってもLLMのような巨大モデルは利用できません.この点に関しては64bitサポートを待ちましょう (参考: GitHub Issue 1GitHub Issue 2).
service_workerでのwebgpuの利用不可についても,利用可能にしようとする動きがあるようで,待っていれば解決しそうです.(参考: Google GroupスレGitHub Issue)

大雑把に纏めてしまうと,他の人に使ってもらうつもりで作るなら2023年12月時点ではCPU前提にした方が無難で,GPUを使う場合は,webglでもwebgpuでも実装の苦労を覚悟することになります.
webglでは速度が欲しくてGPUを使うはずなのにFloat16が非サポート&使えない演算が結構ある点がかなり痛く,一方webgpuは一般利用を考えると対応ブラウザ的に実質Chrome限定になるうえ,service_worker上での利用不可のせいで実装が面倒になるケースが多いです.
将来的にはwebgpuの問題が解決され,特に苦労なく初手webgpuでいい状況になるとは思われますが,現時点ではwebgpuはまだ新しい技術で実験的な機能というポジションになっており,積極的な利用はもうしばらく待つことになりそうです.

なおCPUやwebgpuでの実行であっても,JavaScriptにはネイティブなFloat16がなく現状はUint16をバッファとして代用しており,float16が入出力になっている場合は自力でビット列をエンコード/デコードすることになります.

実装手順②: モデルをONNX or ORT形式でエクスポート

前述の通りメジャーなライブラリではONNXへのエクスポート手段が提供されています.例えばPyTorchでは以下のようにしてできます.

export_model.py
import torch

 # 入力:2次元 -> 出力:3次元
 class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(2,3)
    def forward(self, x):
        h = self.fc1(x)
        return h

model = Model()
# NOTE: 本来はここでmodel を学習したり事前学習済みの重みをロードする.
# ここではサンプルとして入力2次元そのまま&合計の3次元で出力する重みをartificialにセットした
model.fc1.bias.data = torch.zeros([3])
model.fc1.weight.data = torch.FloatTensor([
    [1.0, 0.0], # first elem
    [0.0, 1.0], # second elem
    [1.0, 1.0], # sum of elems
])

# 想定する入力のサンプルが必要.dynamic_axisを利用しないならバッチサイズなど本来可変であるべきパラメータがこのサンプル入力のshapeに固定されるので注意.
sample_input = {"x":torch.FloatTensor([[3.0, 2.0]])}

SAVE_PATH = "/path/to/model.onnx"
torch.onnx.export(
    model,
    sample_input,
    SAVE_PATH,
    input_names=["x"], # sample_inputのキーに合わせる
    output_names=["y"], # 自由.onnxruntimeの利用時にこの名前を覚えておく必要あり.
    dynamic_axes={
        "x": {0:"batch_size"}, # dynamic_axisを利用
    },
)

なおPyTorchは上の方法でエクスポートするとバックエンドにwebglを利用する場合dynamic_axisがバグで動作しません.これはORT形式に変換することで回避することができるほか,エクスポート手段によっては問題なく利用できることもあります.ORT形式への変換は以下のコマンドで実行できます.

python -m onnxruntime.tools.convert_onnx_models_to_ort /path/to/model.onnx
# 生成される/path/to/model.ortをmodel.onnxの代わりに使う

また,モデルがtransformersによるものの場合は,huggingfaceがホストしているoptimumライブラリによるエクスポートを利用した方がよいでしょう.
https://github.com/huggingface/optimum

実装手順③: ブラウザ拡張でonnxruntime-webライブラリを用いてモデルをロード

ブラウザ拡張のスクリプトにはいくつか種類がありますが,その中でもメインの処理を担当する,つまりonnxruntime-webでの推論を実行しうるのは以下の二種類です.

  1. service_worker
    ブラウザ本体のバックグラウンドで動くスクリプトで,拡張機能のホーム的存在
  2. content_scripts
    各webページにinjectされ,そのページの一部として動くスクリプト

どちらでもonnxruntime-webによる推論の実行は可能ですが,多くの拡張機能ではservice_workerで推論を行うことになるはずです.content_scriptsは各ページの一部として動くため,content_scripts上でモデルをロードするとページをたくさん開けばその分同じモデルをメモリ上に複数ロードしてメモリドカ食い気絶してしまうからです.
メモリドカ食いを避けるために,基本的にはcontent_scriptsは各ウェブページからモデルの入力に必要な情報を集めてsendMessageでservice_workerと対話するようにして,実際の推論はservice_workerで行うことになるでしょう.

service_worker上での推論実行

service_worker上でのモデルの推論は一般サイト向け公式のサンプルコードだいたいそのままで動きます.

src/service_worker.ts
import * as ort from "onnxruntime-web";

// URL.createObjectURLの抑制のためにシングルスレッドの設定が必須.ref: https://github.com/microsoft/onnxruntime/issues/14445
ort.env.wasm.numThreads = 1;

// パッケージングする拡張機能上のwasmファイルの配置に合わせてwasmPathsを設定する.後述するvite.config.mtsの設定なら拡張機能パッケージのルートに配置され,デフォルト設定もルート ("./") なので今回は設定不要
// ort.env.wasm.wasmPaths = "./";

// モデルファイルは拡張機能のパッケージに含めず外部からダウンロードすることも可能.今回はパッケージに含めることにした.
// 外部を指定する場合は直接URLを指定できるほか,別途fetchしてArrayBufferを渡す方法もある.
const MODEL_PATH = "./model.onnx"

async function test_ort_service_worker() {
    // setup session with "wasm" (CPU) backend
    const session = await ort.InferenceSession.create(MODEL_PATH, {executionProviders: ["wasm"]});
    // prepare input
    const batch_size = 1;
    const input_dim = 2;
    const input_tensor_data = new Float32Array([2.5, 4.25]); // the data buffer is a flattened tensor (shape=[1,2] => num_elems=[1*2]=[2]).
    const input_tensor = new ort.Tensor("float32", input_tensor_data, [batch_size, input_dim]);
    // forward
    // input and output names (here `x` and `y`) depend on the model definition.
    const feeds = { x: input_tensor };
    const results = await session.run(feeds);
    const output_tensor = results.y;
    console.log(`flattened output tensor: [${output_tensor.data}]`); // should be [2.5, 4.25, 6.75]
    console.log(`- original shape=[${output_tensor.dims}]`);
}
test_ort_service_worker();
src/public/manifest.json
{
    ...,
    "content_security_policy": {
        "extension_pages": "script-src 'self' 'wasm-unsafe-eval'; object-src 'self'"
    },
    "background": {
        "service_worker": "service_worker.js",
        "type": "module"
    },
    ...
}
vite.config.mts
import { resolve } from 'node:path';
import { defineConfig } from 'vite';
import { viteStaticCopy } from 'vite-plugin-static-copy'

export default defineConfig({
    root: 'src',
    build: {
        outDir: '../dist',
        rollupOptions: {
            input: {
                service_worker: resolve(__dirname, 'src/service_worker.ts'),
                content_scripts: resolve(__dirname, 'src/content_scripts.ts'),
            },
            output: { entryFileNames: '[name].js' },
        },
    },
    plugins: [
        viteStaticCopy({targets: [
            {
                src: '../node_modules/onnxruntime-web/dist/*.wasm',
                dest: '.',
            },
        ]}),
    ],
});

注意点として以下の2点があります.

  • WebAssemblyのwasmファイルを拡張機能の配布パッケージに含めて(vite.config.mts中のviteStaticCopyの部分),そのwasmファイルの配置に合わせてservice_workerスクリプト内でort.env.wasm.wasmPathsを設定する
  • InferenceSession.createをする前にort.env.wasm.numThreads = 1をしておく

ort.env.wasm.wasmPathsはデフォルト設定ではパッケージのルート想定になっているので,wasmファイルを拡張機能のルートディレクトリに配置するならort.env.wasm.wasmPathsの設定は不要です.
numThreadsの設定を行わないといけない理由は,onnxruntime-webはデフォルトだとマルチスレッドでセットアップする設定になっているが,この設定ではセットアップ中にManifest v3のservice_worker上では許されないURL.createObjectURLの呼び出しを行ってしまうためのようです.後から気付いたのですがこれtransformers.jsの公式サンプルコードにはその注意書きが書いてあるんですよね.なぜonnxruntime-webには書いてないのか・・・

service_workerでバックエンドにwebgpuを使うには?

前述の通り,2023年12月現在ではservice_workerで使用可能なバックエンドはwasmとwebglのみで,webgpuは使用不可です.
一応service_worker上からoffscreen (参考: chrome公式ドキュメント) を利用したりして作業用のページを開き,そのページでモデルをロード・推論するようにすることで疑似的にwebgpuをservice_workerから利用可能にする手段がある・・・のですが,

  • Manifest v3におけるservice_workerはしばらく何もしないと勝手に死ぬくせにoffscreenだったりのページはちゃんと閉じないと開きっぱなしになる
  • メモリ使用量が巨大になりかねないニューラルモデルをずっとメモリ上にロードしっぱなしというわけにはいかないので開いたら閉じる実装にせざるをえず面倒
  • 開いたページとのservice_workerとのやりとりはmessage passingで面倒
  • Manifest v3ではservice_workerが閉じたかどうかの判定が面倒
  • offscreen,FireFox (愛ブラウザ) で使えなくね・・・?

と,手間がかかるので自分はやりたくないなと思いました.時間が解決してくれそうなので自分はwebgpuが手間なく使えるようになるまで待ちます.

content_scripts上での推論実行

ごく一部のサイトに絞った機能を提供するような拡張機能であれば,content_scriptsでモデルをロードすることも現実的です.ただしcontent_scriptsではdynamic importによってonnxruntime-webをロードするようにし,また実行時にwasmファイルとモデルファイルを拡張機能側から取ってこなければならないため,対応してmanifest.jsonも書き換える必要があります.

src/content_scripts.ts
async function test_ort_content_scripts() {
    // dynamic import
    // NOTE: バックエンドにwebgpuを利用したいなら"onnxruntime-web"ではなく"onnxruntime-web/webgpu"をインポート
    const ort = await import("onnxruntime-web/webgpu");

     // content_scripts上での推論時はURL.createObjectURLの使用が適正のためシングルスレッドである必要はない
     // ort.env.wasm.numThreads = 1;

     // 一方,wasmを拡張機能のパッケージ側から取ってくるための設定が必須
    ort.env.wasm.wasmPaths = chrome.runtime.getURL("./");

    // ローカルに配置したモデルのパスもchrome.runtime.getURL経由になる
    const MODEL_PATH = chrome.runtime.getURL("./model.onnx");
    // バックエンドに"webgpu"を設定.利用できない場合は後ろの"wasm"になる.
    const session = await ort.InferenceSession.create(MODEL_PATH, {executionProviders: ["webgpu", "wasm"]});

    // 以降はservice_workerと同じ
    // prepare input
    const batch_size = 1;
    const input_dim = 2;
    const input_tensor_data = new Float32Array([2.5, 4.25]); // the data buffer is a flattened tensor (shape=[1,2] => num_elems=[1*2]=[2]).
    const input_tensor = new ort.Tensor("float32", input_tensor_data, [batch_size, input_dim]);
    // forward
    const feeds = { x: input_tensor };
    const results = await session.run(feeds);
    const output_tensor = results.y;
    console.log(`flattened output tensor: [${output_tensor.data}] (original shape=[${output_tensor.dims}])`);
}
test_ort_content_scripts();
src/public/manifest.json
{
    ...,
    "content_scripts": [
        {
            "matches": [ "https://example.com/*" ],
            "js": [ "content_scripts.js" ]
        }
    ],
    "web_accessible_resources": [
        {
            "matches": [ "https://example.com/*" ],
            "resources": [
                "assets/*",
                "ort-wasm*.wasm",
                "model.onnx"
            ]
        }
    ],
}

2023年12月時点におけるcontent_scripts上での実行の最大のメリットはバックエンドにwebgpuが面倒な実装なく使えることです.
web_accessible_resourcesはcontent_scripts上から拡張機能ローカルのファイルにアクセスするために必要です.ここではassets/*はviteによってパッケージングされるonnxruntime-webライブラリのために,ort-wasm*.wasmはonnxruntime-webのwasmファイルのために用意しています.

補足: 対抗馬のtransformers.jsについて

ちなみに,利用したいモデルがhuggingface transformersのものなら,onnxruntime-webの代わりにtransformers.jsを使う選択肢があります.
自分で使ったことはありませんが,公式ドキュメントによるとpipelineが使えたりtokenizerもかなり楽に書けるようです.
transformers.jsも内部的にはonnxruntimeを使っており,サンプルコードを見る分には使用感もかなり近いようです.モデルファイルがONNXである必要があるのも両者共通です.ただしtransformers.jsは現状はCPUのみ (参考: GitHub Issue) のようです(onnxruntime-webはGPUも利用可).
また,transformers.jsのquantizeスクリプトはonnxruntime==1.15.1に下げないと動かない(参考: GitHub Issue) とのことです.

おわりに

現状onnxruntime-webはなんでもポンかというとそうでもないのが正直なところですが,それでもブラウザ上でローカル推論できるのは夢が広がるのは間違いありません.せっかくkaggleですごいAIを作る・使う術を学んでいるはずなのですから,仕事でも私生活でもどんどん使っていきたいところです.

自作ブラウザ拡張を作るのは実際にQoLが爆上がりするので本当におすすめです.
自分も一つ拡張機能を作ってみてまあまあ満足のいくものができました.この記事までには公開の審査が間に合わなかった&最終的にニューラルモデルを使わないほうが自分の目的では体験がよかったという元も子もない結果になってしまったのでここでは紹介しないことにしましたが,気分が乗った時に別記事にでもしようかと思います.
ちなみに自分はJavaScriptは本当に気まぐれconsole.log以外で触ったことがなかったんですが,実装はChatGPTとCopilotが重要な処理やらHTMLのレイアウトやらガチのマジでほぼ全部やってくれました.これが新時代・・・

原理的にはブラウザ上でLLMも使えることになりますが,現実的には一般的なPCのリソースでLLMを動かすのは無謀だし,例えばページ内をLLMでふわっと検索みたいな拡張機能は自分専用に作るならともかく一般公開まではまだきつそうです.もっと軽量モデルが増えるといいですね. ➡ 2023/12/31追記: WebAssemblyのメモリ制限のためLargeなLMは現時点では利用不可でした.

今年はリアルがあまりにも鬼すぎてとうとう一回もコンペに参加できませんでした.LLMコンペとかでXで楽しそうなエックセズしてる人を見て悶々とする毎日でしたが,来年はGM目指して頑張りたいところです.

3
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
0