はじめに
ディープラーニングの世界には様々なフレームワークがあり、Caffeで学習されたモデルはCaffeでしか推論できませんでしたが、CaffeモデルをKerasに変換したりとそれぞれ独自に変換するツールもでてきました。そうした中で、異なるフレームワーク間で同じ学習モデルをより便利に使うために、ニューラルネットワークのフォーマットの標準化が進んでいます。
この記事では、ニューラルネットワークの標準フォーマットの1つであるONNXと、ONNXモデルをブラウザで動作させることができるONNX.jsについて紹介したいと思います。
ONNX
ONNX (Open Neural Network Exchange) は、FacebookとMicrosoftが提唱しているニューラルネットワークのモデル表現の標準フォーマットです。近年、AmazonやPFNもこのプロジェクトに参画したようです。
ONNX以外にも、NNEF (Neural Network Exchange Format)というものあり、こちらはKhronosグループが主導していますが、あまり活発ではないようです。
Caffe2, CNTK, MXNet, PyTorch, ChainerなどのフレームワークがONNXをサポートしており、TensorflowやCoreMLなどでも使用できます。
Caffe2で学習したモデルをONNXモデルに変換して、ONNXモデルをCNTKで読み込み推論するといったことができるようになります。ただし、PyTorchのようにExportのみをサポートしていたり、Chainerのように標準には組み込まれておらず別途インストール1する必要があったりするので注意が必要です。
各フレームワークごとのサポート状況やインストール方法については、公式チュートリアルを参照すると良いでしょう。
ONNX Model Zoo
ONNX形式の学習済みモデルは、ONNX Model Zooにまとめられているので、そのまま使ったりファインチューニングするのに使ったりする際に便利です。
まだ数は少ないですが、画像系のメジャーなモデルはだいたいあるようです。
- 画像分類: MobileNet, ResNet, VGG, ...
- 物体検知: Tiny_YOLOv2, SSD, Faster-RCNN, ...
ONNX.js
Tensorflowで学習されたモデルをブラウザで実行することができるTensorflow.jsとライブラリがありますが、ONNX.jsでは、ONNX形式のモデルを使ってブラウザで実行することができます。
デモサイトでは、画像分類やセグメンテーションなどのいくつかのデモを試すことができます。
ちなみに、Tensorflow.jsよりも高速らしいです。
出典: https://github.com/Microsoft/onnxjs#benchmarks
また、ONNXと異なりMicorsoftのリポジトリにあり、Micorsoftが開発の中心のようです。そのためか、MacやChromeよりも、WindowsやEdgeへの対応が優先されているようです。。
まだ開発中で対応していないOSやブラウザもあるので、対応状況は確認しておくと良いです。現時点2で、iOSはすべてのブラウザでComing soon
となっています。
ONNXモデルの準備
それでは、実際にONNXモデルを作って、ONNX.jsを動かしてみたいと思います。
今回は、PyTorchを使ってONNXのモデルを作りたいと思います。なお、バージョンは1.0.0を用いています。(古いバージョンだとnn.onnx
が含まれていないです。)
また、手書き文字データのMNISTの学習モデルを作ることにします。
ネットワークの定義
畳み込み層2層と全結合層2層で構成されるシンプルなネットワークを定義します。
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.Linear(4*4*50, 500)
self.fc2 = nn.Linear(500, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4*4*50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
サポートされるオペレーター
ここで注意すべきは、ONNXやONNX.jsで使用できるオペレーター(Conv
とかReLu
)に制限があることです。厄介なことに、ONNXで使えるオペレーターすべてがONNX.jsで使えるわけではありません。
はじめ、F.log_softmax
を使っておりONNX形式でのExportは成功しても、ONNX.jsでImportでエラーが発生して読み込めませんでした。。
使えるオペレーターは、以下のページで確認できます。
- ONNX: https://github.com/onnx/onnx/blob/master/docs/Operators.md
- ONNX.js: https://github.com/Microsoft/onnxjs/blob/master/docs/operators.md
ONNX.jsでは、Backend(CPUやGPU)によっても使えるオペレーターが変わるようです。厄介ですね。
さきほどのLogSoftmax
のように基本的なオペレーターもまだ対応していなかったりするので、まだまだこれからな感じです。
モデルの学習
通常のPyTorchの学習と同じで、公式のExampleをベースに少し変更しています。(ネットワークの定義でF.log_softmax
が使えなかったのでlossの計算でcross_entropy
を用いています)
import torch
from torchvision import datasets, transforms
# 使用するデバイスの設定(GPUが使える場合はGPUを使い、使えない場合はCPUを用いる)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# モデルの読み込み
model = Net().to(device)
# データセットの準備
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('./data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
])), batch_size=args.batch_size, shuffle=True)
# オプティマイザーの準備
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
# モデルの学習
def train():
model.train()
for _, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.cross_entropy(output, target)
loss.backward()
optimizer.step()
epochs = 10
for _ in range(epochs):
train(model, train_loader, optimizer)
ちなみに、10epochで1分くらいかかり、テストデータに対するAccuracyは0.9883
でした。
ONNXモデルへのExport
ONNXは構築された計算グラフを保存するものなので、PyTorchのようにDefine-by-Runのフレームワークの場合、入力データの形式をダミーデータを用いて指定します。値はランダムでもすべて0でも問題ないです。
ダミーデータを用意すれば、Exportはtorch.onnx.export
を用いるだけで簡単に行うことができます。
# 1x1x28x28 のダミーデータを用意する
dummy_input = torch.randn(1, 1, 28, 28, device=device)
# 学習済みモデルとダミーデータを用いてONNX形式のモデルをファイルに出力
torch.onnx.export(model, dummy_input, 'mnist.onnx')
モデルの中身を見てみる
出力されたファイルは、バイナリデータなので中身を見ることができません。そこで、onnx
を使います。
$ pip install onnx
その他のインストール方法はGithubを参照ください。
import onnx
import onnx.helper
model = onnx.load('mnist.onnx')
print(onnx.helper.printable_graph(model.graph))
以下のような出力を得ることができます。
graph torch-jit-export (
%0[FLOAT, 1x1x28x28]
) initializers (
%1[FLOAT, 20x1x5x5]
%2[FLOAT, 20]
%3[FLOAT, 50x20x5x5]
%4[FLOAT, 50]
%5[FLOAT, 500x800]
%6[FLOAT, 500]
%7[FLOAT, 10x500]
%8[FLOAT, 10]
) {
%9 = Conv[dilations = [1, 1], group = 1, kernel_shape = [5, 5], pads = [0, 0, 0, 0], strides = [1, 1]](%0, %1, %2)
%10 = Relu(%9)
%11 = MaxPool[kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]](%10)
%12 = Conv[dilations = [1, 1], group = 1, kernel_shape = [5, 5], pads = [0, 0, 0, 0], strides = [1, 1]](%11, %3, %4)
%13 = Relu(%12)
%14 = MaxPool[kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]](%13)
%15 = Constant[value = <Tensor>]()
%16 = Reshape(%14, %15)
%17 = Gemm[alpha = 1, beta = 1, transB = 1](%16, %5, %6)
%18 = Relu(%17)
%19 = Gemm[alpha = 1, beta = 1, transB = 1](%18, %7, %8)
return %19
}
Netronを使って可視化する
Netronを使うと手軽に可視化できます。Web版の他、macOSアプリなどもあります。
- Web: https://lutzroeder.github.io/netron/
- macOS:
brew cask install netron
- Python:
pip install netron
-
netron -b
Web版と同じものが起動する
さきほど紹介したONNX Model Zooから適当にONNX形式のモデルをダウンロードすると、すぐに試すことができます。
NetronはONNX専用というわけではなく、複数の形式をサポートしているので、普段の開発/学習でもWebでさくっと確認するときには良いと思います。
フレームワーク | モデルファイルの拡張子 | サポート |
---|---|---|
ONNX | .onnx, .pb, .pbtxt | ○ |
Keras | .h5, .keras | ○ |
CoreML | .mlmodel | ○ |
Caffe2 | predict_net.pb, predict_net.pbtxt | ○ |
MXNet | .model, -symbol.json | ○ |
TensorFlow Lite | .tflite | ○ |
Caffe | .caffemodel, .prototxt | △ |
PyTorch | .pth | △ |
Torch | .t7 | △ |
CNTK | .model, .cntk | △ |
PaddlePaddle | __model__ | △ |
Darknet | .cfg | △ |
scikit-learn | .pkl | △ |
TensorFlow.js | model.json, .pb | △ |
TensorFlow | .pb, .meta, .pbtxt | △ |
※サポーtの△
はexperimental support
と記載があったフレームワークです。もしかしたらうまく動かないものもあるかもしれません。
WebブラウザでONNXモデルのロード
ONNX.jsを使う方法として、<script>
タグを使う方法もありますが、ReactやVueといった最近のWebフレームワークと相性が良いので、npm
で動かす方法を試してみます。
$ npm install onnxjs
import { Tensor, InferenceSession } from 'onnxjs'
async runModel() {
// 推論に用いるセッションの初期化
// Backendには cpu や webgl, wasm を利用することができます
const session = new InferenceSession({ backendHint: 'webgl' })
// ONNX形式のモデルファイル
const modelFile = './mnist.onnx'
// モデルの読み込み
await session.loadModel(modelFile)
// Inputデータの準備(必要であれば事前に前処理をしておく必要があります)
// とりあえず、すべて0とするダミーデータを用意します
const dummy_input = new Float32Array(28 * 28).fill(0)
const inputTensor = new Tensor(dummy_input, 'float32', [1, 1, 28, 28])
// 推論の実行
const outputData = await session.run([inputTensor])
}
Canvasからデータ取得する
ダミーデータでは面白くないので、Canvasに文字を書いて読み取ることを考えます。
input-canvas
とinput-canvas-scaled
の2つのCanvasを用意しておき、input-canvas-scaled
の方は28x28のサイズにしておきdisplay: 'none'
で表示させないようにしておきます。
getImageTensor() {
// input-canvasのcontextを取得
const ctx = document.getElementById('input-canvas').getContext('2d')
// input-canvasのデータを28x28に変換して、input-canvas-scaledに書き込む
const ctxScaled = document.getElementById('input-canvas-scaled').getContext('2d')
ctxScaled.save();
ctxScaled.scale(28 / ctx.canvas.width, 28 / ctx.canvas.height)
ctxScaled.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height)
ctxScaled.drawImage(document.getElementById('input-canvas'), 0, 0)
ctxScaled.restore()
// input-canvas-scaledのデータをTensorに変換
const imageDataScaled = ctxScaled.getImageData(0, 0, 28, 28)
// console.log('imageDataScaled', imageDataScaled)
const input = new Float32Array(784);
for (let i = 0, len = imageDataScaled.data.length; i < len; i += 4) {
input[i / 4] = imageDataScaled.data[i + 3] / 255;
}
const tensor = new Tensor(input, 'float32', [1, 1, 28, 28]);
return tensor
}
先程のrunModel
でダミーデータを用いていたところをgetImageTensor
を使えば、Canvasデータを推論にかけることができます。
以下は、これらを用いて実装したデモです。
おわりに
複数のディープラーニングフレームワークで共通で使える標準フォーマットであるONNXと、ONNXモデルをブラウザで動作させるONNX.jsについて紹介しました。
ONNX.jsについてはまだ発表されたばかりで、まだドキュメント(特に日本語)が少なかったり、機能面でもまだまだな感じがしますが、Exampleは意外とちゃんとしてたり、ONNX自体も対応フレームワークも増えデファクトスタンダードになりつつあるので、ONNX.jsのこれからの開発も期待したいところです。