概要
自作のディープラーニングモデルを誰にでも気軽に使ってもらうため、ペライチのHTMLファイルとモデルファイル(ONNX形式)を配布する方法を考えました。
ブラウザで動作している、JavaScriptで推論を実行します。
制作の流れは以下の通りです。
- 好きな言語でモデルを訓練し、ONNX形式で出力する
- HTMLで推論用のページを作る
- JavaScriptでモデル、データ読み込みから推論結果の出力までの処理を作る
例として、画像分類タスク用の配布ファイルを制作します。
ONNX形式のモデルを出力する
今回はPythonのPytorchを使ってモデルを作ります。参考:PyTorchのドキュメント
必要なライブラリは予めインストールしてください。
torchvisionから訓練済みのモデルを借用していますが、自作のモデルでも動くはず。以下はミニマルな例です。
import torch
from torchvision.models import resnet18, ResNet18_Weights
model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
export_output = torch.onnx.export(model, dummy_input, 'model.onnx')
dummy_inputは(バッチサイズ,チャンネル数,画像高さ,画像幅)にしています。
ただ、torchvisionのIMAGENETで訓練済みモデルの前処理は、ノーマライズ処理されていることが多いです。
JavaScriptでノーマライズ処理を記述すると、複雑になりがちなので、モデルに含めるといいかもしれません。
import torch
from torchvision.models import resnet18, ResNet18_Weights
from torchvision.transforms import Normalize
model = torch.nn.Sequential(
Normalize((0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
resnet18(weights=ResNet18_Weights.IMAGENET1K_V1),
)
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
export_output = torch.onnx.export(
model, dummy_input, 'model.onnx',
input_names=['modelInput'],
output_names=['modelOutput'])
必要な前処理をSequentialで繋げます。
Sequentialはtorch.nn.Moduleのサブクラスのみ受け付けますが、torchvision.transformsの一部のクラスはModuleのサブクラスですので、JavaScriptではキツめの前処理はモデルに含めるといいかもしれません。
あとinput_namesとoutput_namesを指定しておくと、JavaScript用に後で調べる手間が減ります。
上記ファイルをPythonで実行し、model.onnxファイルを出力しておきます。
これで、配布用のモデルファイルが完成しました。
推論用のページを作る
次に配布用のHTMLファイルを作ります。
この例では推論用として、ONNX Runtime Web、前処理用としてOpenCV.jsを使いますので、headタグ内でそれぞれのライブラリを読み込んでいます。
インターフェースとしては、モデル読み込みと画像読み込みのためボタン、推論結果出力のための場所を用意しています。
以下のファイルをindex.htmlとして保存しておきます。
<!DOCTYPE html>
<html>
<head>
<title>Classification</title>
<script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
<script async src="https://docs.opencv.org/master/opencv.js"></script>
</head>
<body>
<h1>Classification</h1>
<div>
<label for="modelInput">先にONNXk形式のモデルを読み込む</label>
<p><input type="file" id="modelInput" accept=".onnx"></p>
</div>
<div>
<label for="imageInput">画像を読み込むと推論開始 出力まで少し時間がかかる</label>
<p><input type="file" id="imageInput" accept="image/*"></p>
</div>
<p id="result"></p>
<script src="./inference.js"></script>
</body>
</html>
推論用の自作JavaScriptを別ファイルで読み込む記述をしていますが、次節のコードをscriptタグ内にベタ書きしても動きます。
モデル、データ読み込みから推論結果の出力までの処理を作る
最後に、JavaScriptで処理を記述します。
JavaScriptを書かないPythonistaでも理解しやすいように、画像の前処理にOpenCV.jsを使います。(Arrayを直接使うのは可読性が悪いですし、、、)
ポイントはコメントで記載しました。
以下のファイルをinference.jsとして保存しておきます。
// HTMLで作ったインターフェース用のDOMを取得する
const modelInput = document.getElementById('modelInput')
const imageInput = document.getElementById('imageInput');
const result = document.getElementById('result');
let session
// 推論結果から、値が最大のカテゴリーのインデックスを取得する
// これはChatGPTに教えてもらった すごい!けど慣れないから読みにくい
async function getArgMax(tensor) {
const data = new Float32Array(
tensor.data.buffer, tensor.data.byteOffset, tensor.data.length
);
const maxIndex = data.reduce((maxIndex, currentValue, currentIndex, array) => {
return currentValue > array[maxIndex] ? currentIndex : maxIndex;
}, 0);
return maxIndex;
}
// モデルファイルを読み込む
modelInput.addEventListener('change', async (e) => {
const file = e.target.files[0];
session = await ort.InferenceSession.create(await file.arrayBuffer());
});
// 画像の前処理から推論結果の出力まで
imageInput.addEventListener('change', async (e) => {
const file = e.target.files[0];
if (file) {
const image = new Image();
image.src = URL.createObjectURL(file);
image.onload = async () => {
// ここからOpenCVを使う
// OpenCV.jsではimreadはRGBAになるため、注意!
const img = cv.imread(image);
// PythonのOpenCVと記述方法が少し違い、
// 先に変換先の画像データをインスタンス化した後に変換関数に通す
const rgbImage = new cv.Mat();
cv.cvtColor(img, rgbImage, cv.COLOR_RGBA2RGB);
// blobFromImageという便利関数がある
// スケーリングと、サイズ変換、チャンネルごとの減算などができるが、
// スケーリングをチャンネルごとに適用できない
const blob = cv.blobFromImage(
rgbImage, 1/255, new cv.Size(224, 224), new cv.Scalar(0, 0, 0), false, false
);
// ONNX Runtime Web用にデータを変換
const imgTensor = new ort.Tensor('float32', blob.data32F, [1, 3, 224, 224]);
// OpenCVで作ったデータは最後にdeleteするのが作法らしい
img.delete();
rgbImage.delete();
blob.delete();
// ONNXモデルに渡して推論実行
// ここで、Python側で指定した入出力のnameを指定
const inputs = {'modelInput': imgTensor};
const output = await session.run(inputs);
idx = await getArgMax(output['modelOutput']);
// 結果を出力(カテゴリーのインデックスのみ)
result.innerText = idx;
};
}
});
できたものの確認
これで一通りの配布ファイルができました。
上記3つのファイル(またはjsファイルはHTML内に書いておくと、2つのファイル)を使って推論を実行します。
index.htmlファイルとinference.jsを同じフォルダに置いて、index.htmlファイルを開くと以下のように表示されると思います。
まず上のボタンを押して、pythonで作ったモデルファイルmodel.onnxを読み込みます。
次に、下のボタンを押して画像を読み込みます。今回はモモンガちゃんに登場してもらいます。
結果は151 なんのカテゴリーでしょうか。このサイトで答え合わせです。
うーん、チワワ。おしい。
調べるとIMAGENETにはモモンガがないようです、残念。
終わりに
ちょっとUIが雑でしたが、制作のイメージが伝わっていれば嬉しいです。
せっかく訓練した最高のモデルなので、たくさん使われるといいですね!