LoginSignup
22
20

More than 5 years have passed since last update.

TensorFlow.jsでPython学習済モデルを読み込み、ブラウザで予測結果を出力する

Last updated at Posted at 2018-07-22

はじめに

TensorFlow Developer Summit 2018 にて Webブラウザ上で機械学習のモデルの構築、学習、学習済みモデルの実行などが可能になるJavaScriptライブラリ「TensorFlow.js」がGoogleによって公開されました。
今はやりっぽいのでJavascriptの勉強もかねてちょっと動かしてみました。

目的

Python モデルをTensorFlow.js で読み込む方法メモ

作ったもの

Pythonのモデルを用いた推測結果をTensorFlow.jsのコンソール上で表示
作成物.png

コード

https://github.com/hiyashichuka/tfjs-iris
※以下を参考にしました
https://github.com/tensorflow/tfjs-examples/tree/master/iris

環境

  • OS

    • Windows 10
    • Bash on Ubuntu on Windows 14.04.5 LTS, Trusty Tahr
  • node

    • node v8.11.3
    • npm 5.6.0
  • Python

    • Python 3.5.5 |Anaconda custom (64-bit)|
  • pip list

    • tensorboard 1.8.0
    • tensorflow 1.8.0
    • tensorflow-gpu 1.8.0
    • tensorflow-hub 0.1.0
    • tensorflow-tensorboard 1.5.1
    • tensorflowjs 0.5.2

主な流れ

  1. Python学習済みモデルを作成 ⇒ Pythonで実行
  2. モデルをTensorFlow.jsで読み込める形に変換 ⇒ Pythonで実行
  3. 2の学習済みモデルを用いて、任意の花の測定値からアイリスのどの品種に属するかを予測 ⇒Javascriptで実行
  4. 予測結果をTensorFlow.jsで表示 ⇒Javascriptで実行

3,4を具体的に
任意の4データを学習済みモデルに入れて(Sepal Length(がく片の長さ), Sepal Width(がく片の幅), Petal Length(花びらの長さ), Petal Width(花びらの幅))'Setosa', 'Versicolor', 'Virginica'のどれに分類されるかをコンソール上で表示させる

1. Python学習済みモデルを作成

Irisクラス分類学習済みモデルを生成します。
今回モデルを作ることは目的ではないので、以下のGoogleのサンプルコードを使います

2. モデルをTensorFlow.jsで読み込める形に変換

1の iris_data.pyを実行します
実行環境がないならGoogle Colaboratoryを使うといいかもです

python iris_data.py

以下のファイルが\tmp\iris.kerasに生成されます
model.jsonを読み込みに使います     

group1-shard1of1  
group2-shard1of1 
model.json

   

3 学習済みモデルを用いて、任意の花の測定値からアイリスのどの品種に属するかを予測

データ読み込み

tf.loadModelからmodel.jsonを読み込みます。

  const MODEL_JSON_URL = /* model.jsonのPath */
  const model = await tf.loadModel(MODEL_JSON_URL);

テストデータを用意

今回はPetal length, Petal width, Sepal length, Sepal widthの四つのデータが必要なので、inputData で渡します

  // Input four date
  // Petal length, Petal width, Sepal length, Sepal width
  const inputData = [5.1, 3.5, 1.4, 0.2];
  const input = tf.tensor2d([inputData], [1, 4]);

予測

model.predictにinputデータを入れることで予測が可能です

  const predictOut = model.predict(input);

4. 予測結果をTensorFlow.jsで表示

結果出力

  const logits = Array.from(predictOut.dataSync());
  console.log("Setosa Probabilities :" + logits[0]);
  console.log("Versicolor Probabilities :" + logits[1]);
  console.log("Virginica Probabilities :" + logits[2]);

  const winner = IRIS_CLASSES[predictOut.argMax(-1).dataSync()[0]];
  console.log("Predict IRIS Class :" + winner);

上記結果

Setosa Probabilities :0.956368625164032
Versicolor Probabilities :0.04204682260751724
Virginica Probabilities :0.0015846537426114082
Predict IRIS Class :Setosa

Setosa > Versicolor > Virginica なのでSetosaクラスに分類されることがわかりました

メモ

  • tfjs-converterを使わない理由

GoogleがSavedModelを使うのを推奨しているため

(Note: TensorFlow has deprecated session bundle format, please switch to SavedModel.)
https://github.com/tensorflow/tfjs-converter

参考

tfjs-examples
https://github.com/tensorflow/tfjs-examples

TensorFlow.jsでMNIST学習済モデルを読み込みブラウザで手書き文字認識をする
https://qiita.com/kaneU/items/ca84c4bfcb47ac53af99

TensorFlow 1.7 新機能 サンプルコードまとめ
https://qiita.com/akimach/items/d150fce405aff37dd463#get-started.md

22
20
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
22
20