LoginSignup
1
1

More than 1 year has passed since last update.

ONNXファイルを自力で読み込むだけ

Last updated at Posted at 2022-07-19

はじめに

ライブラリを使わずに.onnxファイルを読み込むまとまった記事が見つからず四苦八苦したので、ここにまとめます。

最終的には、中間層が二つあるFFNNをPyTorchで生成し.onnxファイルを出力、そしてそれをJavaScript(ES Module)により読み取ります。
とりあえず読み込んだという内容ですので、次のような内容はありません。

  • 汎用的な.onnxフォーマットの読み込み。
  • 読み込んだデータで実際に計算処理を実行。

なお、環境はLinux(Ubuntu)です。

ONNXとは

ONNXとは、ニューラルネットワークのモデルをいい感じに保存するフォーマットの一つです。
フォーマットはProtocol Buffersにより記述されており、また.onnxで表されるONNXファイルはProtocol Buffersによってシリアライズされたファイルとなります。

有名なニューラルネットワークのライブラリ(TensorflowやPyTorchなど)には、ONNXファイルの読み書きをするライブラリが用意されています。
そのため、普段は中のフォーマットを意識することは無いと思います。

具体的なフォーマットについては長くなるので一部を除いて割愛します。知りたい方はONNXのGithubリポジトリを参照してください。

下準備

JavaScriptで.onnxファイルを読み込むためのコードを、シェルスクリプトで生成します。

スクリプトファイル内に記載したため、カレントディレクトリ非依存な書き方になっています。

protocのインストール

onnx.protoファイルから各言語用の読み書き用のコードを生成するためのアプリケーションです。

githubにおいてあるので、適当なバージョンのものを持ってきます。

// 作業ディレクトリ
WORK_DIR="/work"
// protocのバージョン
PROTOBUF_VERSION="3.20.1"
// protocの実行ファイル
PROTOC="${WORK_DIR}/bin/protoc"

wget -P "${WORK_DIR}/" https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOBUF_VERSION}/protoc-${PROTOBUF_VERSION}-linux-x86_64.zip
unzip "${WORK_DIR}/protoc-${PROTOBUF_VERSION}-linux-x86_64.zip" -d ${WORK_DIR}

ts-protoc-genのインストール

TypeScriptの型定義も一緒に作りたいので、これもインストールします。

// protoc-gen-tsの実行ファイル
PROTOC_GEN_TS_PATH="${WORK_DIR}/node_modules/.bin/protoc-gen-ts"

// npm initしてもいいと思う
echo "{}" > "${WORK_DIR}/package.json"
npm install --prefix "${WORK_DIR}" ts-protoc-gen

onnx定義ファイルの取得

Github上のonnx定義ファイルを取得します。

git clone https://github.com/onnx/onnx.git "${WORK_DIR}/onnx"

読み込み用コード生成

protocを使用し、onnx.protoから.onnxファイル読み書きコードを生成します。
出力はCommonJS形式とし、TypeScriptの型定義も出力するようにします。更に、読み込む側はES Moduleで書く予定なので、拡張子は.cjsおよび.ctsに変換します。

なおES Module形式の出力について、Issueは立っており有志のコードはあるようですが、今回は使用していません。

// 出力ディレクトリ
OUT_DIR="./out"

$PROTOC --proto_path=${WORK_DIR}/onnx/onnx \
  --plugin="protoc-gen-ts=${PROTOC_GEN_TS_PATH}" \
  --js_out="import_style=commonjs,binary:${OUT_DIR}" \
  --ts_out="${OUT_DIR}" \
  ${WORK_DIR}/onnx/onnx/onnx.proto

// 末尾の拡張子の頭にcを追加
ls -1 "${OUT_DIR}" | grep -E "onnx_pb.*\..s" | while read filename; do
  mv "${OUT_DIR}/${filename}" "${OUT_DIR}/${filename%.*}.c${filename##*.}"
done

これで$OUT_DIRonnx_pb.cjsonnx_pb.d.ctsファイルが生成されます。

ちなみにonnx.proto3ファイルもありますが、こちらではうまく作ることができませんでした。

.onnxファイル生成

ONNXのGithubリポジトリ内にもいくつかサンプルがありますが、汎用性を考えて自分で作ります。

とりあえず手頃だったPyTorchを使って.onnxファイルを生成します。

import os

from torch import nn
import torch.onnx


class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.fc1 = nn.Linear(10, 3)
        self.fc2 = nn.Linear(3, 2)

    def forward(self, x):
        x = self.fc1(x)
        return self.fc2(x)


torch_model = SimpleNet()

x = torch.randn(100, 10)
torch_out = torch_model(x)

torch.onnx.export(
    torch_model,
    x,
    f"{os.path.dirname(__file__)}/test.onnx",
    export_params=True,
    opset_version=10,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size"},
        "output": {0: "batch_size"},
    },
)

これを実行すると、同じ場所にtest.onnxファイルが生成されます。

読み込み処理

.onnxファイル読み取り

先に生成されたonnx_pb.cjsを使って読み込みます。

import fs from 'fs'
import path from 'path'
import url from 'url'

import onnx from './onnx_pb.cjs'

const filepath = path.dirname(url.fileURLToPath(import.meta.url))
const buf = await fs.promises.readFile(`${filepath}/test.onnx`)

const modelProto = onnx.ModelProto.deserializeBinary(buffer)
const model = modelProto.toObject()

ここから、各レイヤーの情報を取得していきます。

入出力層

以下の通り、それぞれ入力と出力の情報を取得できます。

const inputNode = model.graph.inputList[0]
// model.graph.nodeListのうちで、inputListにinputNode.nameが存在する場所への入力となる

const outputNode = model.graph.outputList[0]
// model.graph.nodeListのうちで、outputListにoutputNode.nameが存在する場所からの出力となる

今回はそれぞれ一つであることが分かっているので、一つだけ取り出しました。

Tensorの読み込み処理

model.graph.initializerListなどに存在する値を実際の配列に変換する関数を作ります。

test.onnxには全て、32bit浮動小数点がrawDataに文字列として保存されていたので、それを次の手順で読み取ります。

  • 文字列をBase64でデコードする
  • Uint8Arrayに変換する
  • IEEE 754かつリトルエンディアンであることに注意して、32bitの浮動小数点に変換する
  • ネストした配列に変換する

以下が関数全体になります。

/**
 * Return Tensor value.
 *
 * @param {onnx.TensorProto.AsObject} tensor TensorProto
 * @returns {number[] | number[][]} Tensor value
 * @see https://github.com/onnx/onnx/blob/e450bc038115dd9ff5ab47670eeaf2a584105064/onnx/onnx.proto#L476
 */
const loadTensor = tensor => {
  const dims = tensor.dimsList
  const length = dims.reduce((s, v) => s * v, 1)
  let rawdata = tensor.rawData
  if (typeof rawdata === 'string') {
    // 文字列をBase64でデコードする
    const buff = Buffer.from(rawdata, 'base64')
    // Uint8Arrayに変換する
    rawdata = new Uint8Array(buff.buffer, buff.byteOffset, buff.byteLength / Uint8Array.BYTES_PER_ELEMENT)
  }

  // リトルエンディアンかつIEEE 754であることに注意して、32bitの浮動小数点に変換する
  const step = rawdata.length / length
  const arr = []
  for (let i = 0; i < rawdata.length; i += step) {
    const sign = rawdata[i + 3] & 0x80 ? -1 : 1
    const exponent = (rawdata[i + 3] & 0x7f) * 2 + ((rawdata[i + 2] & 0x80) >>> 7)
    const exp = exponent === 0 ? 0 : exponent - 127
    const fraction = (rawdata[i + 2] & 0x7f) * 2 ** -7 + rawdata[i + 1] * 2 ** -15 + rawdata[i + 0] * 2 ** -23
    arr.push(sign * (fraction + 1) * 2 ** exp)
  }

  // ネストした配列に変換する
  const ten = []
  let leaf = [ten]
  let c = 0
  for (let i = 0; i < dims.length; i++) {
    const next_leaf = []
    for (const l of leaf) {
      if (i === dims.length - 1) {
        l.push(...arr.slice(c, c + dims[i]))
        c += dims[i]
      } else {
        for (let k = 0; k < dims[i]; k++) {
          next_leaf.push((l[k] = []))
        }
      }
    }
    leaf = next_leaf
  }
  return ten
}

Attributeの読み込み処理

model.graph.nodeList[0].attributeListに存在する値を、実際の値に変換する関数を作ります。

といってもtest.onnxにある型は一部しかないので、それのみ対応しました。

以下が関数全体になります。

/**
 * Return attribute value.
 *
 * @param {onnx.AttributeProto.AsObject} attribute AttributeProto
 * @returns {*} Attribute value
 * @see https://github.com/onnx/onnx/blob/e450bc038115dd9ff5ab47670eeaf2a584105064/onnx/onnx.proto#L108
 */
const loadAttribute = attribute => {
  switch (attribute.type) {
    case 1:
      return attribute.f
    case 2:
      return attribute.i
  }
  throw new Error('Not implemented attribute type.')
}

中間層

中間層は以下の通り取得します。

for (const node of model.graph.nodeList) {
  if (node.opType !== 'Gemm') {
    throw new Error('Not implemented node opType.')
  }

  // 次のコードはここに書かれていると考えてください
}

今回はopTypeGemmのもののみしか存在しないため、それ以外の場合は一応エラーを吐くようにしておきました。

Gemm演算の読み込み

Gemm演算の定義によると、これは次の式を計算します。
$$
\boldsymbol{Y} = \alpha \boldsymbol{A} \boldsymbol{B} + \beta \boldsymbol{C}
$$
分かりにくいかもしれませんが、つまり全結合層です。
各要素は次のとおりの場所に定義されています。

  • Attributes (node.attributeListにある)
    • alpha
    • beta
    • transA (これが$1$の時、上の式の$\boldsymbol{A}$は$\boldsymbol{A}^\mathsf{T}$となる)
    • transB (これが$1$の時、上の式の$\boldsymbol{B}$は$\boldsymbol{B}^\mathsf{T}$となる)
  • Inputs (node.inputListにある)
    • A (入力値)
    • B (重み)
    • C (バイアス)
  • Outputs (node.outputListにある)
    • Y

必要な情報を順番に読み取ると、以下の通りになります。

const attrs = {}
for (const attribute of node.attributeList) {
  attrs[attribute.name] = loadAttribute(attribute)
}

const inits = {}
for (const initializer of model.graph.initializerList) {
  if (initializer.name === node.inputList[1]) {
    inits.w = loadTensor(initializer)
    if (attrs.transB) {
      // Transpose w
    }
    // Calculate w * attrs.alpha
  } else if (initializer.name === node.inputList[2]) {
    inits.b = loadTensor(initializer)
    // Calculate b * attrs.beta
  }
}

inits.wが重み、inits.bがバイアスになります。

なお一つ目の入力$\boldsymbol{A}$は他の場所から渡されます。その場所は、次のいずれかになります。

  • node.inputList[0]の値がinputNode.nameであれば、ニューラルネットワークの入力
  • 他の中間層のうちで、node.outputListnode.inputList[0]の値が存在する場所の出力

おわりに

これで必要なニューラルネットワークの情報が取得できました。あとはこれらを使って計算するだけです。
今回は扱いませんでしたが、他の演算の読み込みについてもGemmと同じ流れになると思います。

実際は、型(整数型、浮動小数点型、複素数型など)や細かいバージョンの違いがあります。
汎用的な実装には、まだ道は遠そうです。

なお、なぜこのようなことを行ったかというと、自作の機械学習パッケージのニューラルネットワークモデルをONNX対応しようと試行錯誤しているためです。
ここに記載した処理はまだ載っていませんが、ニューラルネットワークは実装済みです。
よろしければご覧ください。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1