TensorFlowのチュートリアル(手書き画像データの画像認識)を元に、DeepLearningのネットワークデータを書き出し、Androidで手書き認識をするデモを作成してみました。
学習データの書き出し
TensorFlowのチュートリアルの"MNIST For ML Beginners" のモデルを元にまずは、学習データをPythonにてPC上で書き出します。
"MNIST For ML Beginners"
https://www.tensorflow.org/versions/master/tutorials/mnist/beginners/index.html
ここのチュートリアルを元に、グラフデータ書き出し用にスクリプトを変更しました。
ネットワークデータを書き出すには、グラフ情報とVariableに入ったテンソルデータ(学習された内容)をいっしょにして書き出す必要があるのですが、現時点ではTensorFlowは、グラフ情報とVariableをいっしょに保存することができない様でした。
なので、学習後に、Viriablesの中身を評価して、一度ndarrayに変換し、
# Store variable
_W = W.eval(sess)
_b = b.eval(sess)
ndarrayをConstantに変換し、Variablesの代わりとして利用しグラフを再構成することでグラフと学習データをまとめて書き出しました。
# グラフを再生成
g_2 = tf.Graph()
with g_2.as_default():
# 入力部分に"input"と名前付けしておく
x_2 = tf.placeholder("float", [None, 784], name="input")
# VariablesをConstantで置き換え
W_2 = tf.constant(_W, name="constant_W")
b_2 = tf.constant(_b, name="constant_b")
# 出力部分に"output"と名前付けしておく
y_2 = tf.nn.softmax(tf.matmul(x_2, W_2) + b_2, name="output")
sess_2 = tf.Session()
init_2 = tf.initialize_all_variables()
sess_2.run(init_2)
# グラフを書き出し
graph_def = g_2.as_graph_def()
tf.train.write_graph(graph_def, './tmp/beginner-export',
'beginner-graph.pb', as_text=False)
Android側で呼び出すために、入力部分と出力部分のノードにそれぞれ、"input", "output"という名前をつけておきました。
このモデルの学習と書き出しは数秒ですみました。
Android側
TensorFlowにもともと入っているAndroidデモは、Bazel環境でしかビルドできなかったので、AndroidStudioとNDKだけでAndroidアプリがビルドできる環境を作りました。
TensorFlowのAndroidサンプルをBazelでビルドした時にできるライブラリファイル(.aファイル)を保存しておき、NDKだけでビルドできる様にしました
Android.mkはこの様になりました。
LOCAL_PATH := $(call my-dir)
include $(CLEAR_VARS)
TENSORFLOW_CFLAGS := -frtti \
-fstack-protector-strong \
-fpic \
-ffunction-sections \
-funwind-tables \
-no-canonical-prefixes \
'-march=armv7-a' \
'-mfpu=vfpv3-d16' \
'-mfloat-abi=softfp' \
'-std=c++11' '-mfpu=neon' -O2 \
TENSORFLOW_SRC_FILES := ./tensorflow_jni.cc \
./jni_utils.cc \
LOCAL_MODULE := tensorflow_mnist
LOCAL_ARM_MODE := arm
LOCAL_SRC_FILES := $(TENSORFLOW_SRC_FILES)
LOCAL_CFLAGS := $(TENSORFLOW_CFLAGS)
LOCAL_LDLIBS := \
-Wl,-whole-archive \
$(LOCAL_PATH)/libs/$(TARGET_ARCH_ABI)/libandroid_tensorflow_lib.a \
$(LOCAL_PATH)/libs/$(TARGET_ARCH_ABI)/libre2.a \
$(LOCAL_PATH)/libs/$(TARGET_ARCH_ABI)/libprotos_all_cc.a \
$(LOCAL_PATH)/libs/$(TARGET_ARCH_ABI)/libprotobuf.a \
$(LOCAL_PATH)/libs/$(TARGET_ARCH_ABI)/libprotobuf_lite.a \
-Wl,-no-whole-archive \
$(NDK_ROOT)/sources/cxx-stl/gnu-libstdc++/4.9/libs/$(TARGET_ARCH_ABI)/libgnustl_static.a \
$(NDK_ROOT)/sources/cxx-stl/gnu-libstdc++/4.9/libs/$(TARGET_ARCH_ABI)/libsupc++.a \
-llog -landroid -lm -ljnigraphics -pthread -no-canonical-prefixes '-march=armv7-a' -Wl,--fix-cortex-a8 -Wl,-S \
LOCAL_C_INCLUDES += $(LOCAL_PATH)/include $(LOCAL_PATH)/genfiles $(LOCAL_PATH)/include/third_party/eigen3
NDK_MODULE_PATH := $(call my-dir)
include $(BUILD_SHARED_LIBRARY)
コンパイラオプションや、リンカオプションを上記の様に設定しないと、ビルドが通っても、学習データ(Protocol Buffersデータ)が正しく読み込めませんでした。
Java側で28x28の手書きのピクセルデータを用意して、JNIでc++側に渡し、グラフデータを元に作成したグラフに入力します。
無事に認識できました。
学習データの置き換え
上記のモデルだと、認識率が91%くらいなので、TensorFlowの"Deep MNIST for Experts"にあるDeep Learningをつかったモデル(認識率99.2%)に置き換えてみました。
"Deep MNIST for Experts"
https://www.tensorflow.org/versions/master/tutorials/mnist/pros/index.html
学習データ書き出し用スクリプト
https://github.com/miyosuda/TensorFlowAndroidMNIST/blob/master/trainer-script/expert.py
DropOutのノードが入っていると、Android側での実行時に何故かエラーが出てしまったのと、元々DropOutノードは学習時にしか必要ないものなので、書き出し時にはグラフから外しました。
自分の環境(MacBook Pro)で学習には1時間くらいかかりました。
inputノードとoutputノードの名前を同じにしておいたので、書き出し後は、学習データを差し替えるだけで、Android側のコードはすべてそのままで実行できます。
手書き認識を試してみたところ、かなり正確に0〜9の数字を認識するのが確認できました。
ソース
上記のソース一式ははこちら
https://github.com/miyosuda/TensorFlowAndroidMNIST