LoginSignup
130

More than 3 years have passed since last update.

TensorFlow: Pythonで学習したデータをAndroidで実行

Last updated at Posted at 2016-01-19

TensorFlowのチュートリアル(手書き画像データの画像認識)を元に、DeepLearningのネットワークデータを書き出し、Androidで手書き認識をするデモを作成してみました。

tensorflow_mnist_screen0.png

学習データの書き出し

TensorFlowのチュートリアルの"MNIST For ML Beginners" のモデルを元にまずは、学習データをPythonにてPC上で書き出します。

"MNIST For ML Beginners"
https://www.tensorflow.org/versions/master/tutorials/mnist/beginners/index.html

ここのチュートリアルを元に、グラフデータ書き出し用にスクリプトを変更しました。

ネットワークデータを書き出すには、グラフ情報とVariableに入ったテンソルデータ(学習された内容)をいっしょにして書き出す必要があるのですが、現時点ではTensorFlowは、グラフ情報とVariableをいっしょに保存することができない様でした。

なので、学習後に、Viriablesの中身を評価して、一度ndarrayに変換し、

# Store variable
_W = W.eval(sess)
_b = b.eval(sess)

ndarrayをConstantに変換し、Variablesの代わりとして利用しグラフを再構成することでグラフと学習データをまとめて書き出しました。

# グラフを再生成
g_2 = tf.Graph()
with g_2.as_default():
    # 入力部分に"input"と名前付けしておく
    x_2 = tf.placeholder("float", [None, 784], name="input")

    # VariablesをConstantで置き換え
    W_2 = tf.constant(_W, name="constant_W")
    b_2 = tf.constant(_b, name="constant_b")

    # 出力部分に"output"と名前付けしておく
    y_2 = tf.nn.softmax(tf.matmul(x_2, W_2) + b_2, name="output")

    sess_2 = tf.Session()
    init_2 = tf.initialize_all_variables()
    sess_2.run(init_2)

    # グラフを書き出し
    graph_def = g_2.as_graph_def()
    tf.train.write_graph(graph_def, './tmp/beginner-export',
                         'beginner-graph.pb', as_text=False)

Android側で呼び出すために、入力部分と出力部分のノードにそれぞれ、"input", "output"という名前をつけておきました。

このモデルの学習と書き出しは数秒ですみました。

Android側

TensorFlowにもともと入っているAndroidデモは、Bazel環境でしかビルドできなかったので、AndroidStudioとNDKだけでAndroidアプリがビルドできる環境を作りました。

TensorFlowのAndroidサンプルをBazelでビルドした時にできるライブラリファイル(.aファイル)を保存しておき、NDKだけでビルドできる様にしました

Android.mkはこの様になりました。

Makefile
LOCAL_PATH := $(call my-dir)

include $(CLEAR_VARS)

TENSORFLOW_CFLAGS     := -frtti \
  -fstack-protector-strong \
  -fpic \
  -ffunction-sections \
  -funwind-tables \
  -no-canonical-prefixes \
  '-march=armv7-a' \
  '-mfpu=vfpv3-d16' \
  '-mfloat-abi=softfp' \
  '-std=c++11' '-mfpu=neon' -O2 \

TENSORFLOW_SRC_FILES := ./tensorflow_jni.cc \
    ./jni_utils.cc \

LOCAL_MODULE    := tensorflow_mnist
LOCAL_ARM_MODE  := arm
LOCAL_SRC_FILES := $(TENSORFLOW_SRC_FILES)
LOCAL_CFLAGS    := $(TENSORFLOW_CFLAGS)

LOCAL_LDLIBS    := \
    -Wl,-whole-archive \
    $(LOCAL_PATH)/libs/$(TARGET_ARCH_ABI)/libandroid_tensorflow_lib.a \
    $(LOCAL_PATH)/libs/$(TARGET_ARCH_ABI)/libre2.a \
    $(LOCAL_PATH)/libs/$(TARGET_ARCH_ABI)/libprotos_all_cc.a \
    $(LOCAL_PATH)/libs/$(TARGET_ARCH_ABI)/libprotobuf.a \
    $(LOCAL_PATH)/libs/$(TARGET_ARCH_ABI)/libprotobuf_lite.a \
    -Wl,-no-whole-archive \
    $(NDK_ROOT)/sources/cxx-stl/gnu-libstdc++/4.9/libs/$(TARGET_ARCH_ABI)/libgnustl_static.a \
    $(NDK_ROOT)/sources/cxx-stl/gnu-libstdc++/4.9/libs/$(TARGET_ARCH_ABI)/libsupc++.a \
    -llog -landroid -lm -ljnigraphics -pthread -no-canonical-prefixes '-march=armv7-a' -Wl,--fix-cortex-a8 -Wl,-S \

LOCAL_C_INCLUDES += $(LOCAL_PATH)/include $(LOCAL_PATH)/genfiles $(LOCAL_PATH)/include/third_party/eigen3

NDK_MODULE_PATH := $(call my-dir)

include $(BUILD_SHARED_LIBRARY)

コンパイラオプションや、リンカオプションを上記の様に設定しないと、ビルドが通っても、学習データ(Protocol Buffersデータ)が正しく読み込めませんでした。

Java側で28x28の手書きのピクセルデータを用意して、JNIでc++側に渡し、グラフデータを元に作成したグラフに入力します。

tensorflow_mnist_screen0.png

無事に認識できました。

学習データの置き換え

上記のモデルだと、認識率が91%くらいなので、TensorFlowの"Deep MNIST for Experts"にあるDeep Learningをつかったモデル(認識率99.2%)に置き換えてみました。

"Deep MNIST for Experts"
https://www.tensorflow.org/versions/master/tutorials/mnist/pros/index.html

学習データ書き出し用スクリプト
https://github.com/miyosuda/TensorFlowAndroidMNIST/blob/master/trainer-script/expert.py

DropOutのノードが入っていると、Android側での実行時に何故かエラーが出てしまったのと、元々DropOutノードは学習時にしか必要ないものなので、書き出し時にはグラフから外しました。

自分の環境(MacBook Pro)で学習には1時間くらいかかりました。

inputノードとoutputノードの名前を同じにしておいたので、書き出し後は、学習データを差し替えるだけで、Android側のコードはすべてそのままで実行できます。

手書き認識を試してみたところ、かなり正確に0〜9の数字を認識するのが確認できました。

ソース

上記のソース一式ははこちら
https://github.com/miyosuda/TensorFlowAndroidMNIST

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
130