機械学習勉強中の初心者です。
この投稿をみて自分も実装してみました。、
https://qiita.com/yukagil/items/ca84c4bfcb47ac53af99
つくったもの
ピンクの破線内に文字を書くと精度よく検出します。スマホでも動きます。
触発された投稿との違いは下記だと思います。
- checkpointのデータからtensorflowjs用のモデルを作成している
- vuetifyつかった
- モデルデータのダウンロード状況表示
- スマホでも動く
コードとか
せつめいとか
PythonでMNISTを分類できるモデルは作っていたのですが、下記の手順が必要でした。
- checkpointとして保存したデータをpbファイルに変換する
- pbファイルのデータをTensorflow.js用に変換する
- Tensorflow.jsでモデルを読み込む
1. checkpointとして保存したデータをpbファイルに変換する
私はcheckpointとしてファイルを保存していたのですが、それをmodel(pbファイル)として保存する必要があるようでした。
python初心者なので間違いも多いと思いすが、下記のコードで変換できました。
check_point/my_model-1000.meta
が保存しているグラフです。
# coding:utf-8
# tensorflow version1.13.1
import tensorflow as tf
saver = tf.train.import_meta_graph('check_point/my_model-1000.meta',clear_devices=True)
with tf.Session() as sess:
chpt_state = tf.train.get_checkpoint_state('check_point/')
if chpt_state:
last_model = chpt_state.model_checkpoint_path
saver.restore(sess,last_model)
print ("model was loaded",last_model)
else:
print ("model cannot loaded")
exit(1)
graph = tf.get_default_graph()
graph_def = graph.as_graph_def()
x = graph.get_tensor_by_name('x:0')
out = graph.get_tensor_by_name('reduce/out:0')
tf.saved_model.simple_save(sess, './models', inputs={"x": x}, outputs={"reduce/out": out})
これでpbファイルが生成できます。
2. pbファイルのデータをTensorflow.js用に変換する
ココでハマりました。
よくわからないですが、tensorlowjs 1.0.1のtensorflow_converterを使用しようとするとかきのようなエラーで止まってしまいます。
$ tensorflowjs_converter --input_format tf_saved_model --output_format tfjs_graph_model --saved_model_tags serve models/ out/
Traceback (most recent call last):
File "/usr/local/bin/tensorflowjs_converter", line 10, in <module>
sys.exit(main())
File "/usr/local/lib/python2.7/dist-packages/tensorflowjs/converters/converter.py", line 358, in main
strip_debug_ops=FLAGS.strip_debug_ops)
File "/usr/local/lib/python2.7/dist-packages/tensorflowjs/converters/tf_saved_model_conversion_v2.py", line 268, in convert_tf_saved_model
model = load(saved_model_dir, saved_model_tags)
TypeError: load() takes exactly 1 argument (2 given)
この問題は、tensorflowjs の0.8.5を利用することで解決できました。
sudo pip install tensorflow==0.8.5
tensorflowjs_converterのオプションが1.0系と違うので下記のようにオプションを与えて変換しました。
## tensorflowjs==0.8.5を使用
$ tensorflowjs_converter --input=tf_saved_model --output_node_names='reduce/out' --output_json='model.json' pb_models/ web_model/
$ ls web_model/
group1-shard1of4 group1-shard2of4 group1-shard3of4 group1-shard4of4 model.json weights_manifest.json
3. Tensorflow.jsでモデルを読み込む
モデルのダウンロード状況が知りたかったので、一旦モデルをzipで取得するようにしています。
tf.io.browserFiles
でファイルを読み込んだあとに、tf.loadGraphModel
を利用しています。
そのおかげで、くるくる回るスピナーを実装できました。
import * as tf from '@tensorflow/tfjs';
import _ from 'lodash';
import axios from 'axios';
import JSZip from 'jszip';
export default {
load : async (onDownloadProgress=_.noop) => {
let { data:modelData } = await axios.get('/web_model/model.zip',{responseType:'blob',onDownloadProgress})
let zip = await new JSZip().loadAsync(modelData);
let sortedZipFiles = [
...zip.filter((path,file)=>/model.json/.test(path)),
...zip.filter((path,file)=>!/model.json/.test(path))];
let fileNames = sortedZipFiles.map(zipObj=>zipObj.name)
let arraybuffers = await Promise.all(sortedZipFiles.map(zipObj=>zipObj.async("arraybuffer")))
let modelFiles = _.zip(fileNames,arraybuffers)
.map(([name,arraybuffer])=> new File([arraybuffer],name));
return tf.loadGraphModel(tf.io.browserFiles(modelFiles))
}
}
感想とか
イケてる感じで実装できてよかったです。
初めてVuetifyを使ったのでVue.jsの実装が微妙なところも多いかもしれないのが気になります。
Tensorflow.js 使うメリットってフロントエンドで実装できることだと思うのですが、モデルデータとか盗まれ放題な気がするのでどうなんでしょう?
githubのリンク作るとき https://ghlinkcard.com/ で作りましたがとても便利ですね