LoginSignup
2
1

More than 3 years have passed since last update.

TensorFlowで学習したモデルをNode.jsで利用

Posted at

TensorFlow.js というライブラリでJavaScriptからTensorFlowを利用することができるようです。

PythonのTensorFlowで学習したモデルをNode.jsで読み込んで推論に利用してみました。PythonもNode.jsもGoogle Colaboratory上です。

手順

  1. TensorFlowで学習
  2. 学習したモデルをファイルに保存
  3. モデルのファイルをTensorFlow.js用に変換
  4. TensorFlow.js をインストール
  5. Node.jsでモデルを読み込んで推論を実行

TensorFlowのバージョン情報

!pip list | grep tensorflow
tensorflow                    2.4.0          
tensorflow-addons             0.8.3          
tensorflow-datasets           4.0.1          
tensorflow-estimator          2.4.0          
tensorflow-gcs-config         2.4.0          
tensorflow-hub                0.10.0         
tensorflow-metadata           0.26.0         
tensorflow-privacy            0.2.2          
tensorflow-probability        0.11.0  

※コマンドの実行結果は2021/01/08時点です。

1. TensorFlowで学習

サンプルとしてXORを学習させます。

import numpy as np
from matplotlib import pyplot as plt
import tensorflow as tf

in_size = 2
out_size = 2

# XORの学習データ
x = np.asarray([[0, 0], [0, 1], [1, 0], [1, 1]])
y = np.asarray([0, 1, 1, 0])
y_onehot = tf.keras.backend.one_hot(y, out_size)

# モデル構造を定義
hidden_size = 2
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(hidden_size, activation='relu', input_shape=(in_size,)))
model.add(tf.keras.layers.Dense(out_size, activation='softmax'))

# モデルを構築
model.compile(
    loss = "categorical_crossentropy",
    optimizer = tf.optimizers.Adam(learning_rate=0.05),
    metrics=["accuracy"])

# 学習を実行
result = model.fit(x, y_onehot,
    batch_size=100,
    epochs=20,
    verbose=1,
    validation_data=(x, y_onehot)) # 手抜き

# 学習の様子をグラフへ描画
def plotLearning():
  # ロスの推移をプロット
  plt.plot(result.history['loss'])
  plt.plot(result.history['val_loss'])
  plt.title('Loss')
  plt.legend(['train', 'test'], loc='upper left')
  plt.show()

  # 正解率の推移をプロット
  plt.plot(result.history['accuracy'])
  plt.plot(result.history['val_accuracy'])
  plt.title('Accuracy')
  plt.legend(['train', 'test'], loc='upper left')
  plt.show()

plotLearning()

2. 学習したモデルをファイルに保存

ディレクトリ作成

!rm -rf model && mkdir model

保存

model.save("model/sample1.h5")

この名前でファイルが1つ作成されます。

3. モデルのファイルをTensorFlow.js用に変換

変換するツールをインストール

!pip install tensorflowjs

tensorflow-hub-0.9.0tensorflowjs-2.8.3 がインストールされます。

!tensorflowjs_converter --input_format=keras model/sample1.h5 model/sample-tfjs

保存結果を確認

!find model/sample-tfjs
model/sample-tfjs
model/sample-tfjs/group1-shard1of1.bin
model/sample-tfjs/model.json

ディレクトリが作成され、その中に2つのファイルが作成されました。たぶんメタ情報のJSONと数値のバイナリ。

4. TensorFlow.js をインストール

!npm install @tensorflow/tfjs-node

@tensorflow/tfjs-node@2.8.3 がインストールされます。

5. Node.jsでモデルを読み込んで推論を実行

JavaScriptソースファイル作成

%%writefile sample.js
const tf = require('@tensorflow/tfjs-node');
tf.loadLayersModel('file://model/sample-tfjs/model.json').then((model) => {
  const inputs = [[0, 0], [0, 1], [1, 0], [1, 1]];
  const result = model.predict(tf.tensor(inputs)).argMax(1).dataSync();
  console.log(result);
});

実行

!node sample.js
node-pre-gyp info This Node instance does not support builds for N-API version 6
node-pre-gyp info This Node instance does not support builds for N-API version 7
node-pre-gyp info This Node instance does not support builds for N-API version 6
node-pre-gyp info This Node instance does not support builds for N-API version 7
2021-01-08 12:43:44.532076: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2021-01-08 12:43:44.546231: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2299995000 Hz
2021-01-08 12:43:44.546491: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x44c2e00 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2021-01-08 12:43:44.546547: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
Int32Array [ 0, 1, 1, 0 ]

なにやら警告っぽいのが表示されたのですが、よくわかっていないです。

最後の [0, 1, 1, 0] でXORが実行できているのを確認できます。

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1