Edited at

TensorFlow.jsでPython学習済モデルを読み込み、ブラウザで予測結果を出力する

More than 1 year has passed since last update.


はじめに

TensorFlow Developer Summit 2018 にて Webブラウザ上で機械学習のモデルの構築、学習、学習済みモデルの実行などが可能になるJavaScriptライブラリ「TensorFlow.js」がGoogleによって公開されました。

今はやりっぽいのでJavascriptの勉強もかねてちょっと動かしてみました。


目的

Python モデルをTensorFlow.js で読み込む方法メモ


作ったもの

Pythonのモデルを用いた推測結果をTensorFlow.jsのコンソール上で表示

作成物.png


コード

https://github.com/hiyashichuka/tfjs-iris

※以下を参考にしました

https://github.com/tensorflow/tfjs-examples/tree/master/iris


環境



  • OS


    • Windows 10

    • Bash on Ubuntu on Windows 14.04.5 LTS, Trusty Tahr




  • node


    • node v8.11.3

    • npm 5.6.0




  • Python


    • Python 3.5.5 |Anaconda custom (64-bit)|




  • pip list


    • tensorboard 1.8.0

    • tensorflow 1.8.0

    • tensorflow-gpu 1.8.0

    • tensorflow-hub 0.1.0

    • tensorflow-tensorboard 1.5.1

    • tensorflowjs 0.5.2




主な流れ


  1. Python学習済みモデルを作成 ⇒ Pythonで実行

  2. モデルをTensorFlow.jsで読み込める形に変換 ⇒ Pythonで実行

  3. 2の学習済みモデルを用いて、任意の花の測定値からアイリスのどの品種に属するかを予測 ⇒Javascriptで実行

  4. 予測結果をTensorFlow.jsで表示 ⇒Javascriptで実行

3,4を具体的に

任意の4データを学習済みモデルに入れて(Sepal Length(がく片の長さ), Sepal Width(がく片の幅), Petal Length(花びらの長さ), Petal Width(花びらの幅))'Setosa', 'Versicolor', 'Virginica'のどれに分類されるかをコンソール上で表示させる


1. Python学習済みモデルを作成

Irisクラス分類学習済みモデルを生成します。

今回モデルを作ることは目的ではないので、以下のGoogleのサンプルコードを使います

https://github.com/tensorflow/tfjs-examples/tree/master/iris/python


2. モデルをTensorFlow.jsで読み込める形に変換

1の iris_data.pyを実行します

実行環境がないならGoogle Colaboratoryを使うといいかもです

python iris_data.py

以下のファイルが\tmp\iris.kerasに生成されます

model.jsonを読み込みに使います     

group1-shard1of1  

group2-shard1of1 
model.json

   


3 学習済みモデルを用いて、任意の花の測定値からアイリスのどの品種に属するかを予測


データ読み込み

tf.loadModelからmodel.jsonを読み込みます。

  const MODEL_JSON_URL = /* model.jsonのPath */

const model = await tf.loadModel(MODEL_JSON_URL);


テストデータを用意

今回はPetal length, Petal width, Sepal length, Sepal widthの四つのデータが必要なので、inputData で渡します

  // Input four date

// Petal length, Petal width, Sepal length, Sepal width
const inputData = [5.1, 3.5, 1.4, 0.2];
const input = tf.tensor2d([inputData], [1, 4]);


予測

model.predictにinputデータを入れることで予測が可能です

  const predictOut = model.predict(input);


4. 予測結果をTensorFlow.jsで表示


結果出力

  const logits = Array.from(predictOut.dataSync());

console.log("Setosa Probabilities :" + logits[0]);
console.log("Versicolor Probabilities :" + logits[1]);
console.log("Virginica Probabilities :" + logits[2]);

const winner = IRIS_CLASSES[predictOut.argMax(-1).dataSync()[0]];
console.log("Predict IRIS Class :" + winner);

上記結果

Setosa Probabilities :0.956368625164032

Versicolor Probabilities :0.04204682260751724
Virginica Probabilities :0.0015846537426114082
Predict IRIS Class :Setosa

Setosa > Versicolor > Virginica なのでSetosaクラスに分類されることがわかりました


メモ


  • tfjs-converterを使わない理由

GoogleがSavedModelを使うのを推奨しているため


(Note: TensorFlow has deprecated session bundle format, please switch to SavedModel.)

https://github.com/tensorflow/tfjs-converter



参考

tfjs-examples

https://github.com/tensorflow/tfjs-examples

TensorFlow.jsでMNIST学習済モデルを読み込みブラウザで手書き文字認識をする

https://qiita.com/kaneU/items/ca84c4bfcb47ac53af99

TensorFlow 1.7 新機能 サンプルコードまとめ

https://qiita.com/akimach/items/d150fce405aff37dd463#get-started.md