LoginSignup
58
41

More than 5 years have passed since last update.

ONNX.jsを使ってWebブラウザでディープラーニング

Last updated at Posted at 2018-12-24

はじめに

ディープラーニングの世界には様々なフレームワークがあり、Caffeで学習されたモデルはCaffeでしか推論できませんでしたが、CaffeモデルをKerasに変換したりとそれぞれ独自に変換するツールもでてきました。そうした中で、異なるフレームワーク間で同じ学習モデルをより便利に使うために、ニューラルネットワークのフォーマットの標準化が進んでいます。

この記事では、ニューラルネットワークの標準フォーマットの1つであるONNXと、ONNXモデルをブラウザで動作させることができるONNX.jsについて紹介したいと思います。

ONNX

ONNX.png

ONNX (Open Neural Network Exchange) は、FacebookとMicrosoftが提唱しているニューラルネットワークのモデル表現の標準フォーマットです。近年、AmazonやPFNもこのプロジェクトに参画したようです。

ONNX以外にも、NNEF (Neural Network Exchange Format)というものあり、こちらはKhronosグループが主導していますが、あまり活発ではないようです。

Caffe2, CNTK, MXNet, PyTorch, ChainerなどのフレームワークがONNXをサポートしており、TensorflowやCoreMLなどでも使用できます。

Caffe2で学習したモデルをONNXモデルに変換して、ONNXモデルをCNTKで読み込み推論するといったことができるようになります。ただし、PyTorchのようにExportのみをサポートしていたり、Chainerのように標準には組み込まれておらず別途インストール1する必要があったりするので注意が必要です。

各フレームワークごとのサポート状況やインストール方法については、公式チュートリアルを参照すると良いでしょう。

ONNX Model Zoo

ONNX形式の学習済みモデルは、ONNX Model Zooにまとめられているので、そのまま使ったりファインチューニングするのに使ったりする際に便利です。

まだ数は少ないですが、画像系のメジャーなモデルはだいたいあるようです。

  • 画像分類: MobileNet, ResNet, VGG, ...
  • 物体検知: Tiny_YOLOv2, SSD, Faster-RCNN, ...

ONNX.js

Tensorflowで学習されたモデルをブラウザで実行することができるTensorflow.jsとライブラリがありますが、ONNX.jsでは、ONNX形式のモデルを使ってブラウザで実行することができます。

デモサイトでは、画像分類やセグメンテーションなどのいくつかのデモを試すことができます。

onnxjs.png

ちなみに、Tensorflow.jsよりも高速らしいです。

perf-resnet50.png
出典: https://github.com/Microsoft/onnxjs#benchmarks

また、ONNXと異なりMicorsoftのリポジトリにあり、Micorsoftが開発の中心のようです。そのためか、MacやChromeよりも、WindowsやEdgeへの対応が優先されているようです。。

まだ開発中で対応していないOSやブラウザもあるので、対応状況は確認しておくと良いです。現時点2で、iOSはすべてのブラウザでComing soonとなっています。

ONNXモデルの準備

それでは、実際にONNXモデルを作って、ONNX.jsを動かしてみたいと思います。
今回は、PyTorchを使ってONNXのモデルを作りたいと思います。なお、バージョンは1.0.0を用いています。(古いバージョンだとnn.onnxが含まれていないです。)

また、手書き文字データのMNISTの学習モデルを作ることにします。

ネットワークの定義

畳み込み層2層と全結合層2層で構成されるシンプルなネットワークを定義します。

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

サポートされるオペレーター

ここで注意すべきは、ONNXやONNX.jsで使用できるオペレーター(ConvとかReLu)に制限があることです。厄介なことに、ONNXで使えるオペレーターすべてがONNX.jsで使えるわけではありません。

はじめ、F.log_softmaxを使っておりONNX形式でのExportは成功しても、ONNX.jsでImportでエラーが発生して読み込めませんでした。。

使えるオペレーターは、以下のページで確認できます。

ONNX.jsでは、Backend(CPUやGPU)によっても使えるオペレーターが変わるようです。厄介ですね。

さきほどのLogSoftmaxのように基本的なオペレーターもまだ対応していなかったりするので、まだまだこれからな感じです。

モデルの学習

通常のPyTorchの学習と同じで、公式のExampleをベースに少し変更しています。(ネットワークの定義でF.log_softmaxが使えなかったのでlossの計算でcross_entropyを用いています)

import torch
from torchvision import datasets, transforms

# 使用するデバイスの設定(GPUが使える場合はGPUを使い、使えない場合はCPUを用いる)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# モデルの読み込み
model = Net().to(device)

# データセットの準備
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                   ])), batch_size=args.batch_size, shuffle=True)

# オプティマイザーの準備
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

# モデルの学習
def train():
   model.train()
   for _, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

epochs = 10
for _ in range(epochs):
   train(model, train_loader, optimizer)

ちなみに、10epochで1分くらいかかり、テストデータに対するAccuracyは0.9883でした。

ONNXモデルへのExport

ONNXは構築された計算グラフを保存するものなので、PyTorchのようにDefine-by-Runのフレームワークの場合、入力データの形式をダミーデータを用いて指定します。値はランダムでもすべて0でも問題ないです。

ダミーデータを用意すれば、Exportはtorch.onnx.exportを用いるだけで簡単に行うことができます。

# 1x1x28x28 のダミーデータを用意する
dummy_input = torch.randn(1, 1, 28, 28, device=device)

# 学習済みモデルとダミーデータを用いてONNX形式のモデルをファイルに出力
torch.onnx.export(model, dummy_input, 'mnist.onnx')

モデルの中身を見てみる

出力されたファイルは、バイナリデータなので中身を見ることができません。そこで、onnxを使います。

$ pip install onnx

その他のインストール方法はGithubを参照ください。

import onnx
import onnx.helper

model = onnx.load('mnist.onnx')
print(onnx.helper.printable_graph(model.graph))

以下のような出力を得ることができます。

graph torch-jit-export (
  %0[FLOAT, 1x1x28x28]
) initializers (
  %1[FLOAT, 20x1x5x5]
  %2[FLOAT, 20]
  %3[FLOAT, 50x20x5x5]
  %4[FLOAT, 50]
  %5[FLOAT, 500x800]
  %6[FLOAT, 500]
  %7[FLOAT, 10x500]
  %8[FLOAT, 10]
) {
  %9 = Conv[dilations = [1, 1], group = 1, kernel_shape = [5, 5], pads = [0, 0, 0, 0], strides = [1, 1]](%0, %1, %2)
  %10 = Relu(%9)
  %11 = MaxPool[kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]](%10)
  %12 = Conv[dilations = [1, 1], group = 1, kernel_shape = [5, 5], pads = [0, 0, 0, 0], strides = [1, 1]](%11, %3, %4)
  %13 = Relu(%12)
  %14 = MaxPool[kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]](%13)
  %15 = Constant[value = <Tensor>]()
  %16 = Reshape(%14, %15)
  %17 = Gemm[alpha = 1, beta = 1, transB = 1](%16, %5, %6)
  %18 = Relu(%17)
  %19 = Gemm[alpha = 1, beta = 1, transB = 1](%18, %7, %8)
  return %19
}

Netronを使って可視化する

Netronを使うと手軽に可視化できます。Web版の他、macOSアプリなどもあります。

さきほど紹介したONNX Model Zooから適当にONNX形式のモデルをダウンロードすると、すぐに試すことができます。

mnist_cnn.png

NetronはONNX専用というわけではなく、複数の形式をサポートしているので、普段の開発/学習でもWebでさくっと確認するときには良いと思います。

フレームワーク モデルファイルの拡張子 サポート
ONNX .onnx, .pb, .pbtxt
Keras .h5, .keras
CoreML .mlmodel
Caffe2 predict_net.pb, predict_net.pbtxt
MXNet .model, -symbol.json
TensorFlow Lite .tflite
Caffe .caffemodel, .prototxt
PyTorch .pth
Torch .t7
CNTK .model, .cntk
PaddlePaddle __model__
Darknet .cfg
scikit-learn .pkl
TensorFlow.js model.json, .pb
TensorFlow .pb, .meta, .pbtxt

※サポーtのexperimental supportと記載があったフレームワークです。もしかしたらうまく動かないものもあるかもしれません。

WebブラウザでONNXモデルのロード

ONNX.jsを使う方法として、<script>タグを使う方法もありますが、ReactやVueといった最近のWebフレームワークと相性が良いので、npmで動かす方法を試してみます。

$ npm install onnxjs
import { Tensor, InferenceSession } from 'onnxjs'

async runModel() {
    // 推論に用いるセッションの初期化
    // Backendには cpu や webgl, wasm を利用することができます
    const session = new InferenceSession({ backendHint: 'webgl' })

    // ONNX形式のモデルファイル
    const modelFile = './mnist.onnx'

    // モデルの読み込み
    await session.loadModel(modelFile)

    // Inputデータの準備(必要であれば事前に前処理をしておく必要があります)
    // とりあえず、すべて0とするダミーデータを用意します
    const dummy_input = new Float32Array(28 * 28).fill(0)
    const inputTensor = new Tensor(dummy_input, 'float32', [1, 1, 28, 28])

    // 推論の実行
    const outputData = await session.run([inputTensor])
}

Canvasからデータ取得する

ダミーデータでは面白くないので、Canvasに文字を書いて読み取ることを考えます。

input-canvasinput-canvas-scaledの2つのCanvasを用意しておき、input-canvas-scaledの方は28x28のサイズにしておきdisplay: 'none'で表示させないようにしておきます。

getImageTensor() {
    // input-canvasのcontextを取得
    const ctx = document.getElementById('input-canvas').getContext('2d')

     // input-canvasのデータを28x28に変換して、input-canvas-scaledに書き込む
    const ctxScaled = document.getElementById('input-canvas-scaled').getContext('2d')
    ctxScaled.save();
    ctxScaled.scale(28 / ctx.canvas.width, 28 / ctx.canvas.height)
    ctxScaled.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height)
    ctxScaled.drawImage(document.getElementById('input-canvas'), 0, 0)
    ctxScaled.restore()

     // input-canvas-scaledのデータをTensorに変換
    const imageDataScaled = ctxScaled.getImageData(0, 0, 28, 28)
    // console.log('imageDataScaled', imageDataScaled)

     const input = new Float32Array(784);
    for (let i = 0, len = imageDataScaled.data.length; i < len; i += 4) {
        input[i / 4] = imageDataScaled.data[i + 3] / 255;
    }
    const tensor = new Tensor(input, 'float32', [1, 1, 28, 28]);

    return tensor
}

先程のrunModelでダミーデータを用いていたところをgetImageTensorを使えば、Canvasデータを推論にかけることができます。

以下は、これらを用いて実装したデモです。

onnxjs.gif

おわりに

複数のディープラーニングフレームワークで共通で使える標準フォーマットであるONNXと、ONNXモデルをブラウザで動作させるONNX.jsについて紹介しました。

ONNX.jsについてはまだ発表されたばかりで、まだドキュメント(特に日本語)が少なかったり、機能面でもまだまだな感じがしますが、Exampleは意外とちゃんとしてたり、ONNX自体も対応フレームワークも増えデファクトスタンダードになりつつあるので、ONNX.jsのこれからの開発も期待したいところです。


  1. Chainerの場合onnx-chainerをインストールするとONNXを使うことができます 

  2. 2018.12.24現在 

58
41
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
58
41