TypeScript でも AI を動かしたい!
TypeScript/JavScript で画像系 DL を動かした時の備忘録です.
モチベーション
- TypeScript の勉強中
- 文法はなんとなく理解したので,何か動くものを作りたい
- 得意分野の Deep Learning 系で何かやってみる
モデルについて
この記事の主旨は TypeScript 側の実装なので,モデルの話は軽くサラッと触れるだけにしておきます.
項目 | 説明 |
---|---|
タスク | セマンティックセグメンテーション |
クラス | キリン・シマウマ |
学習 | pytorch |
推論 | ONNX |
上記の設定でモデル作成を行いました.
また,TypeScript 側での実装をシンプルにするために,前処理・後処理はできるだけ ONNX モデル内部に埋め込んでいます.
具体的には以下の処理を ONNX に組み込み済みです:
-
RGBA 入力対応(ブラウザの
canvas
に合わせるため) - Normalize
-
Sigmoid
など
環境
前提条件
下記の環境を想定しています.
- node v20.19.1
- npm 11.3.0
- npx 11.3.0
- vscode 拡張機能 Live Server
フォルダ構成
./playground/
┣━ node_modules/
┣━ src/
┃ ┣━ image-loader.ts
┃ ┣━ onnx-inference.ts
┃ ┣━ onnx-loader.ts
┃ ┣━ postprocess.ts
┃ ┣━ preprocess.ts
┃ ┗━ utils.ts
┣━ build.js
┣━ index.html # html
┣━ model_rgba.onnx # ONNX モデル
┣━ package-lock.json
┣━ package.json
┗━ tsconfig.json
環境構築
-
ライブラリインストール
$ npm init -y $ npm install typescript onnxruntime-web esbuild --save-dev $ npx tsc --init
-
tsconfig.json 置き換え
{ "compilerOptions": { "target": "ES2020", "module": "ES2022", "moduleResolution": "node", "outDir": "./dist", "rootDir": "./src", "strict": true, "esModuleInterop": true }, "include": ["./src/**/*"] }
-
ビルドスクリプト
// build.js const esbuild = require('esbuild'); esbuild.build({ entryPoints: ['./src/utils.ts'], bundle: true, outfile: './dist/utils.js', format: 'esm', platform: 'browser', target: ['es2022'], sourcemap: false, logLevel: 'info' }).catch(() => process.exit(1)); const fs = require('fs'); fs.copyFileSync( 'node_modules/onnxruntime-web/dist/ort-wasm-simd-threaded.jsep.wasm', 'dist/ort-wasm-simd-threaded.jsep.wasm' );
-
package.json 置き換え
{ "name": "onnx_new", "version": "1.0.0", "description": "", "main": "index.js", "scripts": { "test": "echo \"Error: no test specified\" && exit 1", "build": "node build.js" }, "keywords": [], "author": "", "license": "ISC", "type": "commonjs", "devDependencies": { "esbuild": "^0.25.5", "onnxruntime-web": "^1.22.0", "typescript": "^5.8.3" } }
コード解説
-
html
<!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8"> <meta name="viewport" content="width=device-width, initial-scale=1.0"> <title>ONNX Runtime</title> <script src="./dist/utils.js" type="module"></script> </head> <body> <input type="file" id="imageInput" accept="image/*"> <button id="runBtn">Run Inference</button> <br> <canvas id="imageCanvas" width="512" height="512"></canvas> <canvas id="maskCanvas" width="512" height="512"></canvas> <script type="module"> await window.utils.loadModel(); window.utils.setupImageInput("imageInput", "imageCanvas") document.getElementById("runBtn").addEventListener("click", async () => { const canvas = document.getElementById("imageCanvas"); if (!canvas) return; const mask = await window.utils.runInference(canvas); if (mask){ window.utils.drawMask("maskCanvas", mask); } }) </script> </body> </html>
- canvas
imageCanvas
に元画像を描画 - canvas
maskCanvas
に推論後の画像を描画 - button
runBtn
が押されると推論が開始される
- canvas
-
image-loader.ts
async function loadImageToCanvas( input: HTMLInputElement, canvas: HTMLCanvasElement ): Promise<void> { const file = input.files?.[0]; if (!file) return; const reader = new FileReader(); reader.onload = (e) => { const img = new Image(); img.onload = () => { canvas.width = img.width; canvas.height = img.height; const ctx = canvas.getContext('2d')!; ctx.drawImage(img, 0, 0); }; img.src = e.target?.result as string; }; reader.readAsDataURL(file); } export function setupImageInput(inputId: string, canvasId: string): void { const input = document.getElementById(inputId) as HTMLInputElement; const canvas = document.getElementById(canvasId) as HTMLCanvasElement; if (!input || !canvas) { console.warn('input or canvas not found!'); return; } input.addEventListener('change', () => { loadImageToCanvas(input, canvas); }); console.log('Setup done'); } export function drawMask(canvasId: string, mask: ImageData): void { const canvas = document.getElementById(canvasId) as HTMLCanvasElement; if (!canvas) { console.warn('input or canvas not found!'); return; } canvas.width = mask.width; canvas.height = mask.height; const ctx = canvas.getContext('2d'); if (!ctx) return; ctx.putImageData(mask, 0, 0); }
- 画像周りの utils ファイル
-
onnx-inference.ts
import * as ort from 'onnxruntime-web'; import { getSession } from './onnx-loader.js'; import { postprocessToMask } from './postprocess.js'; import { preprocessCanvas } from './preprocess.js'; function resizeCanvasSize( canvas: HTMLCanvasElement, targetWidth: number, targetHeight: number ): HTMLCanvasElement { const resizedCanvas = document.createElement('canvas'); resizedCanvas.width = targetWidth; resizedCanvas.height = targetHeight; const ctx = resizedCanvas.getContext('2d'); ctx?.drawImage(canvas, 0, 0, targetWidth, targetHeight); return resizedCanvas; } function resizeImage(src: ImageData, targetWidth: number, targetHeight: number): ImageData | null { const srcCanvas = document.createElement('canvas'); srcCanvas.width = src.width; srcCanvas.height = src.height; const srcCtx = srcCanvas.getContext('2d'); srcCtx?.putImageData(src, 0, 0); const dstCanvas = document.createElement('canvas'); dstCanvas.width = targetWidth; dstCanvas.height = targetHeight; const dstCtx = dstCanvas.getContext('2d'); if (!dstCtx) { return null; } dstCtx.drawImage(srcCanvas, 0, 0, targetWidth, targetHeight); console.log(targetWidth, targetHeight); return dstCtx.getImageData(0, 0, targetWidth, targetHeight); } export async function runInference(canvas: HTMLCanvasElement): Promise<ImageData | null> { const session = getSession(); const inputName = session.inputNames[0]; const metadataArray = session.inputMetadata as any[]; const inputMeta = metadataArray[0]; const inputShape = inputMeta.shape; const [_, H, W, __] = inputShape; const originalCanvasWidth = canvas.width; const originalCanvasHeight = canvas.height; const resizedCanvas = resizeCanvasSize(canvas, W, H); const inputData = preprocessCanvas(resizedCanvas); if (!inputData) return null; const inputTensor = new ort.Tensor('float32', inputData, [1, H, W, 4]); const feeds: Record<string, ort.Tensor> = {}; feeds[session.inputNames[0]] = inputTensor; const results = await session.run(feeds); const output = results[session.outputNames[0]]; const mask = postprocessToMask(output); const resizedMask = resizeImage(mask, originalCanvasWidth, originalCanvasHeight); return resizedMask; }
- 推論の核の部分
- 後述する
onnx-loader
,preprocess
,postprocess
の関数を用いて,モデルのロード,前処理,推論,後処理を行う
-
onnx-loader.ts
import * as ort from 'onnxruntime-web'; let session: ort.InferenceSession | null = null; export async function loadModel(modelPath: string = './model_rgba.onnx'): Promise<void> { session = await ort.InferenceSession.create(modelPath, { executionProviders: ['wasm'], }); console.log('ONNX model loaded!'); } export function getSession(): ort.InferenceSession { if (!session) throw new Error('Model not loaded yet'); return session; }
-
model_rgba.onnx
のロードと session の生成を行う
-
-
postprocess.ts
import * as ort from 'onnxruntime-web'; export function postprocessToMask(tensor: ort.Tensor, classIndex: number = 0): ImageData { const [_, C, H, W] = tensor.dims; const data = tensor.data as Float32Array; const imageData = new Uint8ClampedArray(W * H * 4); for (let y: number = 0; y < H; y++) { for (let x: number = 0; x < W; x++) { const i = y * W + x; const v = data[classIndex * H * W + i]; const p = Math.min(255, Math.max(0, v * 255)); imageData[i * 4 + 0] = p; imageData[i * 4 + 1] = p; imageData[i * 4 + 2] = p; imageData[i * 4 + 3] = 255; } } return new ImageData(imageData, W, H); }
- 後処理を行う関数
- モデルは
(1, クラス数, 高さ, 幅)
という shape で出力するので,指定のクラス(classIndex=0 の時シマウマ,classIndex=1の時キリン)の結果に対して後処理を実施 - モデルは sigmoid 後の結果を出力するので,255 倍する
- canvas 描画用に RGBA に変換
-
preprocess.ts
export function preprocessCanvas(canvas: HTMLCanvasElement): Float32Array | null { const ctx = canvas.getContext('2d'); if (!ctx) { return null; } const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height); const data = imageData.data; return new Float32Array(data); }
- 読み込んだ画像を float32 の array で返す前処理用の関数
-
utils.ts
import { setupImageInput, drawMask } from './image-loader.js'; import { runInference } from './onnx-inference.js'; import { loadModel, getSession } from './onnx-loader.js'; (window as any).utils = { loadModel, getSession, drawMask, setupImageInput, runInference, };
- html から JavaScript の関数を呼ぶためのコード
今回のコードでは,これらの TypeScript のコードを utils.js
という 1 つのファイルにバンドルしています.
utils.js
を html から読み込むことで,推論が可能になります
ビルド
$ npm run build
上記を実行すると dist/
というフォルダが作成され,2 つのファイルが出力されます.
これで実行の準備は完了です.
デモ
元画像:Photo by Brad Greenlee / CC BY 2.0
※この画像はスクリーンショットにも含まれています.
見事,シマウマのセグメンテーションができました!
おわりに
今回は TypeScript の勉強も兼ねて,TypeScript/JavaScript を用いて onnx モデルの推論を行う方法を解説しました.
読み込んだ画像のデータは RGBA なので扱いづらいところもありますが,前処理・後処理含めて onnx にしておけば結構簡単に実行できることがわかりました.
また何か面白いことをしたら記事にしようと思います.
今回は
onnxruntime-web
などをバンドルして 1 ファイルにまとめています.
実運用の際は,ライブラリのライセンス(MIT など)やモデルの再配布可否にご注意ください.
👉今回使用したモデル含めコードなどは追って GitHub にアップ予定です.
👉 Push しました.リポジトリ