3
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

JavaScriptからTensorFlow Liteのモデルを使って推論してみる

Last updated at Posted at 2021-08-22

はじめに

TensorFlow Lite なるものをご存知でしょうか。主にエッジ用途など低リソース環境のために、重みパラメータを整数化してモデルサイズの縮小と計算量の削減を行うというものです。

実はこのモデル、JavaScriptからでも使えます。モデルをダウンロードしてクライアントサイドで推論する場合、モデルサイズの削減はトラフィック削減にもつながるので嬉しいです。また、計算量が減ることは比較的CPUが非力なモバイル用途においてもよいと思うので、Cordovaなどを使ってHTML/JavaScriptベースでモバイルアプリを作る場合に活用することもできそうです。

ということで、簡単なMNIST手書き数字認識のモデルを学習して重みパラメータを整数化し、JavaScriptから呼び出すところまでを一通り作ってみたいと思います。

検証環境

ご覧のように tfjs-tflite はアルファ版です。きっと未サポートの機能も多々あるので、どんなモデルでも使えるというわけにはいかないかもしれませんが、基本的なレイヤーを使っていれば大丈夫なのでしょう(少なくとも Conv2DMaxPooling2D は使えています)。

モデル学習

前回の記事と同様に、まずはMNISTを使って適当なモデルを学習します。

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.datasets import mnist

# データセットの準備
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train[:, :, :, np.newaxis].astype("float32") / 255.0  # (60000, 28, 28, 1)
x_test = x_test[:, :, :, np.newaxis].astype("float32") / 255.0    # (10000, 28, 28, 1)

# モデルの準備
model = Sequential([
  Conv2D(32, kernel_size=(3, 3), activation='relu', padding="same", input_shape=(28, 28, 1)),
  MaxPooling2D(pool_size=(2, 2)),
  Conv2D(64, kernel_size=(3, 3), activation='relu', padding="same"),
  MaxPooling2D(pool_size=(2, 2)),
  Flatten(),
  Dense(128, activation='relu'),
  Dense(10, activation='softmax')
])
model.compile("Adam", loss="sparse_categorical_crossentropy", metrics="sparse_categorical_accuracy")

# 学習
model.fit(x_train, y_train, batch_size=256, epochs=10, validation_data=(x_test, y_test))

保存するときはこんな感じです。

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
def representative_data_gen():
  for input_value in tf.data.Dataset.from_tensor_slices(x_train).batch(1).take(100):
    # Model has only one input so each data point has one element.
    yield [input_value]
converter.representative_dataset = representative_data_gen
tflite_model_int8 = converter.convert()
with open("model_int8.tflite", "wb") as f:
  f.write(tflite_model_int8)

重みパラメータを量子化する時には、値の取りうる範囲を調べておかないと効率的な量子化ができないので、学習データの一部を入力することでそれを行います。

変換後のモデルは重みがfloat32 → int8になっていますので、単純計算でモデルサイズが1/4になります。実際にint8になっていることを確認するため、ファイルサイズを確認すると

!ls -l model_int8.tflite
-rw-r--r-- 1 root root 428560 Aug 21 13:45 ./model_int8.tflite

となっています。一方、model.summary() の結果は以下のようになります。

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 28, 28, 32)        320       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 14, 14, 64)        18496     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 64)          0         
_________________________________________________________________
flatten (Flatten)            (None, 3136)              0         
_________________________________________________________________
dense (Dense)                (None, 128)               401536    
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1290      
=================================================================
Total params: 421,642
Trainable params: 421,642
Non-trainable params: 0
_________________________________________________________________

「パラメータ数+α」くらいのファイルサイズ(バイト)になっているので、確かに重みパラメータが1バイト (int8) で表現されているらしいと分かります。

このモデルをGoogle Drive経由でダウンロードしておきます。詳細は前回の記事をご覧ください。

推論

TensorFlow.js の Task API を用いる方法

少し調べた範囲だと、TensorFlow.js の Task API 経由で呼び出す方法が出てきました。

モデルに合うTaskが以下のページにあれば使えますが、任意のモデルを対象にできるわけではないようです。
@tensorflow-models/tasks - npm

例えば画像分類であれば以下のページにモデルの要件が書かれていますが
画像分類器の統合 | TensorFlow Lite

RGB 入力のみがサポートされています (channelsは 3 である必要があります)。

とあるように、入力はカラー画像でなければなりません。同じ画像分類であっても、MNISTの手書き文字認識のようなグレースケール画像の場合だと使えません。

定義済みTaskに合わないモデルは、どうしようもないのでしょうか?

TensorFlow Lite APIを直接呼び出す方法

実はTensorFlow.jsから使えるTensorFlow LiteのAPIがあるので、これを直接使うことが可能です。こんな感じでHTMLとスクリプトを記述し、tflite.html という名前で保存します。

tflite.html
<!DOCTYPE html>
<html lang="ja">
  <head>
    <meta charset="utf-8">
  </head>

  <body>
    <div><img src="sample_5.png" id="input_image"></div>
    <div>Prediction: <span id="pred"></span></div>

    <!-- Load TensorFlow.js -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"></script>

    <!-- Import @tensorflow/tfjs-tflite -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-tflite/dist/tf-tflite.min.js"></script>

    <!-- Place your code in the script tag below. You can also use an external .js file -->
    <script>
      // Notice there is no 'import' statement. 'tf' is available on the index-page
      // because of the script tag above.

      async function run(){
        const model = await tflite.loadTFLiteModel('./model_int8.tflite');
        const x = tf.browser.fromPixels(document.getElementById('input_image'), 1).reshape([1, 28, 28, 1]).div(tf.scalar(255));
        const y = model.predict(x); // [1, 10]
        document.getElementById('pred').textContent = y.reshape([-1]).argMax().arraySync();
      }
      run();
    </script>
  </body>
</html>

model_int8.tflite と、入力画像 sample_5.png がHTMLと同じ場所にあるとします。http://https:// でアクセスできる場所を用意してください。sample_5.png は適当にペイントのソフトで描いたこの画像です。
sample_5.png
見た感じ、ほとんど前回の記事と同じです。差分を見ていきましょう。

まず、TensorFlow LiteのAPIを使用するため、スクリプトを追加で読み込んでいます。

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-tflite/dist/tf-tflite.min.js"></script>

モデルを読み込むためのAPIが変わります。

const model = await tflite.loadTFLiteModel('./model_int8.tflite');

えーと、差分はこれだけです。

実行してみる

ブラウザから tflite.html のURLにアクセスします。数秒待って「Prediction: 5」と表示されたら成功です。やってみたら意外と簡単でした。
image.png

3
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?