はじめに
TensorFlow Developer Summit 2018 にて Webブラウザ上で機械学習のモデルの構築、学習、学習済みモデルの実行などが可能になるJavaScriptライブラリ「TensorFlow.js」がGoogleによって公開されました。
今はやりっぽいのでJavascriptの勉強もかねてちょっと動かしてみました。
目的
Python モデルをTensorFlow.js で読み込む方法メモ
作ったもの
Pythonのモデルを用いた推測結果をTensorFlow.jsのコンソール上で表示
コード
https://github.com/hiyashichuka/tfjs-iris
※以下を参考にしました
https://github.com/tensorflow/tfjs-examples/tree/master/iris
環境
-
OS
- Windows 10
- Bash on Ubuntu on Windows 14.04.5 LTS, Trusty Tahr
-
node
- node v8.11.3
- npm 5.6.0
-
Python
- Python 3.5.5 |Anaconda custom (64-bit)|
-
pip list
- tensorboard 1.8.0
- tensorflow 1.8.0
- tensorflow-gpu 1.8.0
- tensorflow-hub 0.1.0
- tensorflow-tensorboard 1.5.1
- tensorflowjs 0.5.2
主な流れ
- Python学習済みモデルを作成 ⇒ Pythonで実行
- モデルをTensorFlow.jsで読み込める形に変換 ⇒ Pythonで実行
- 2の学習済みモデルを用いて、任意の花の測定値からアイリスのどの品種に属するかを予測 ⇒Javascriptで実行
- 予測結果をTensorFlow.jsで表示 ⇒Javascriptで実行
3,4を具体的に
任意の4データを学習済みモデルに入れて(Sepal Length(がく片の長さ), Sepal Width(がく片の幅), Petal Length(花びらの長さ), Petal Width(花びらの幅))'Setosa', 'Versicolor', 'Virginica'のどれに分類されるかをコンソール上で表示させる
1. Python学習済みモデルを作成
Irisクラス分類学習済みモデルを生成します。
今回モデルを作ることは目的ではないので、以下のGoogleのサンプルコードを使います
2. モデルをTensorFlow.jsで読み込める形に変換
1の iris_data.pyを実行します
実行環境がないならGoogle Colaboratoryを使うといいかもです
python iris_data.py
以下のファイルが\tmp\iris.kerasに生成されます
model.jsonを読み込みに使います
group1-shard1of1
group2-shard1of1
model.json
3 学習済みモデルを用いて、任意の花の測定値からアイリスのどの品種に属するかを予測
データ読み込み
tf.loadModelからmodel.jsonを読み込みます。
const MODEL_JSON_URL = /* model.jsonのPath */
const model = await tf.loadModel(MODEL_JSON_URL);
テストデータを用意
今回はPetal length, Petal width, Sepal length, Sepal widthの四つのデータが必要なので、inputData で渡します
// Input four date
// Petal length, Petal width, Sepal length, Sepal width
const inputData = [5.1, 3.5, 1.4, 0.2];
const input = tf.tensor2d([inputData], [1, 4]);
予測
model.predictにinputデータを入れることで予測が可能です
const predictOut = model.predict(input);
4. 予測結果をTensorFlow.jsで表示
結果出力
const logits = Array.from(predictOut.dataSync());
console.log("Setosa Probabilities :" + logits[0]);
console.log("Versicolor Probabilities :" + logits[1]);
console.log("Virginica Probabilities :" + logits[2]);
const winner = IRIS_CLASSES[predictOut.argMax(-1).dataSync()[0]];
console.log("Predict IRIS Class :" + winner);
上記結果
Setosa Probabilities :0.956368625164032
Versicolor Probabilities :0.04204682260751724
Virginica Probabilities :0.0015846537426114082
Predict IRIS Class :Setosa
Setosa > Versicolor > Virginica なのでSetosaクラスに分類されることがわかりました
メモ
- tfjs-converterを使わない理由
GoogleがSavedModelを使うのを推奨しているため
(Note: TensorFlow has deprecated session bundle format, please switch to SavedModel.)
https://github.com/tensorflow/tfjs-converter
参考
tfjs-examples
https://github.com/tensorflow/tfjs-examples
TensorFlow.jsでMNIST学習済モデルを読み込みブラウザで手書き文字認識をする
https://qiita.com/kaneU/items/ca84c4bfcb47ac53af99
TensorFlow 1.7 新機能 サンプルコードまとめ
https://qiita.com/akimach/items/d150fce405aff37dd463#get-started.md