Help us understand the problem. What is going on with this article?

Deep Learningアプリケーション開発 (4) TensorFlow Lite with Python

この記事について

機械学習、Deep Learningの専門家ではない人が、Deep Learningを応用したアプリケーションを作れるようになるのが目的です。MNIST数字識別する簡単なアプリケーションを、色々な方法で作ってみます。特に、組み込み向けアプリケーション(Edge AI)を意識しています。
モデルそのものには言及しません。数学的な話も出てきません。Deep Learningモデルをどうやって使うか(エッジ推論)、ということに重点を置いています。

  1. Kerasで簡単にMNIST数字識別モデルを作り、Pythonで確認
  2. TensorFlowモデルに変換してPythonで使用してみる (Windows, Linux)
  3. TensorFlowモデルに変換してCで使用してみる (Windows, Linux)
  4. TensorFlow Liteモデルに変換してPythonで使用してみる (Windows, Linux) <--- 今回の内容
  5. TensorFlow Liteモデルに変換してCで使用してみる (Linux)
  6. TensorFlow Liteモデルに変換してC++で使用してみる (Raspberry Pi)
  7. TensorFlow LiteモデルをEdge TPU上で動かしてみる (Raspberry Pi)

Google Colaboratory版

Google Colaboratory + Tensorflow2.x版を本記事の後ろに追記しました。
Google Colaboratory版

今回の内容

  • KerasモデルをTensorFlow Liteモデルに変換する
  • TensorFlow Lite用モデルを使って、入力画像から数字識別するPythonアプリケーションを作る

ソースコードとサンプル入力画像: https://github.com/take-iwiw/CNN_NumberDetector/tree/master/04_TensorflowLite_Python

環境

  • OS: Windows 10 (64-bt)
  • OS(on VirtualBox): Ubuntu 16.04
  • CPU = Intel Core i7-6700@3.4GHz (物理コア=4、論理プロセッサ数=8)
  • GPU = NVIDIA GeForce GTX 1070 (← GPUは無くても大丈夫です)
  • 開発環境: Anaconda3 64-bit (Python3.6.8)
  • TensorFlow 1.12.0, tf-nightly(1.14)

今回の内容は、WindowsとLinux(Ubuntu)のどちらでも動きますが、本記事の説明はWindowsメインで行います。

KerasモデルをTensorFlow Liteモデルに変換する

ベースとなるモデルは、第1回目にKerasで簡単に作ったものを使用します。(https://qiita.com/take-iwiw/items/796ec8560563625ace34 )。
既存のKeras用モデル(conv_mnist.h5)から、TensorFlow Lite用モデル(conv_mnist.tflite)を作成します。
方法は2つあります。どちらも、tensorflowをcondaやpipでインストールしていれば自動で入っています。(TensorFlow 1.9以降)

  • 変換用コマンドを使用する
  • 変換用のPython APIを使用する

変換方法や、変換コマンドの詳細は以下にまとまっています。TensorFlowモデル(.pb)等の他の形式からの変換方法や、色々なオプション指定の方法が説明されています。
https://www.tensorflow.org/lite/convert

変換用コマンドを使用する

コマンド
tflite_convert --output_file=conv_mnist.tflite --keras_model_file=conv_mnist.h5

上記コマンドを実行するだけです。非常に簡単です。

以下のようなワーニングが出ましたが、特に問題なく動きました。

2019-03-09 15:06:42.549799: I tensorflow/core/common_runtime/process_util.cc:69] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.
WARNING:tensorflow:Error in loading the saved optimizer state. As a result, your model is starting with a freshly initialized optimizer.

変換用のPython APIを使用する

keras_to_tflite.py
# -*- coding: utf-8 -*-
import tensorflow as tf

if __name__ == '__main__':
    # for tensorflow 1.12
    # converter = tf.contrib.lite.TFLiteConverter.from_keras_model_file("conv_mnist.h5")
    # for tensorflow-nightly (1.14)
    converter = tf.lite.TFLiteConverter.from_keras_model_file("conv_mnist.h5")
    tflite_model = converter.convert()
    open("conv_mnist.tflite", "wb").write(tflite_model)

上記のようなスクリプトを実行するだけです。Kerasでモデル作成時に、tflite用モデルも一緒に出力してもいいかもしれません。

ハマった点

Traceback (most recent call last):
  File "C:\Users\tak\Anaconda3\envs\tf_cpu\lib\site-packages\tensorflow\contrib\lite\python\interpreter_wrapper\tensorflow_wrap_interpreter_wrapper.py", line 18, in swig_import_helper
    fp, pathname, description = imp.find_module('_tensorflow_wrap_interpreter_wrapper', [dirname(__file__)])
  File "C:\Users\tak\Anaconda3\envs\tf_cpu\lib\imp.py", line 297, in find_module
    raise ImportError(_ERR_MSG.format(name), name=name)
ImportError: No module named '_tensorflow_wrap_interpreter_wrapper'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "number_detector_tflite.py", line 19, in <module>
    interpreter = tf.contrib.lite.Interpreter(model_path="conv_mnist.tflite")
  File "C:\Users\tak\Anaconda3\envs\tf_cpu\lib\site-packages\tensorflow\contrib\lite\python\interpreter.py", line 52, in __init__
    _interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromFile(
  File "C:\Users\tak\Anaconda3\envs\tf_cpu\lib\site-packages\tensorflow\python\util\lazy_loader.py", line 53, in __getattr__
    module = self._load()
  File "C:\Users\tak\Anaconda3\envs\tf_cpu\lib\site-packages\tensorflow\python\util\lazy_loader.py", line 42, in _load
    module = importlib.import_module(self.__name__)
  File "C:\Users\tak\Anaconda3\envs\tf_cpu\lib\importlib\__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 994, in _gcd_import
  File "<frozen importlib._bootstrap>", line 971, in _find_and_load
  File "<frozen importlib._bootstrap>", line 955, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 665, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 678, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "C:\Users\tak\Anaconda3\envs\tf_cpu\lib\site-packages\tensorflow\contrib\lite\python\interpreter_wrapper\tensorflow_wrap_interpreter_wrapper.py", line 28, in <module>
    _tensorflow_wrap_interpreter_wrapper = swig_import_helper()
  File "C:\Users\tak\Anaconda3\envs\tf_cpu\lib\site-packages\tensorflow\contrib\lite\python\interpreter_wrapper\tensorflow_wrap_interpreter_wrapper.py", line 20, in swig_import_helper
    import _tensorflow_wrap_interpreter_wrapper
ModuleNotFoundError: No module named '_tensorflow_wrap_interpreter_wrapper'

もしも上記のようなエラーが出る場合は、
conda install tensorflow の代わりに、pip install tf-nightly を試してください。また、コード内に出てくるtensorflow.liteの指定を、tf.contrib.lite. から tf.lite. に変えてみてください。
Linuxだと問題なくできたのですが、Windowsだとtf-nightlyじゃないとダメでした。OSの違いというか、パッケージのバージョンによるものかもしれませんが、詳しく調べてはいません。

TensorFlow Lite用モデルを使って、入力画像から数字識別するCアプリケーションを作る

簡単なサンプルコードがここ(https://www.tensorflow.org/lite/convert/python_api#tensorflow_lite_python_interpreter_ )にあるので、これを参考にPythonコードを書きます。
TensorFlowに比べると、非常にシンプルになっています。

まず最初に、OpenCVを使用して画像入力しています。そして、グレースケール化、28x28にリサイズ、白黒反転、入力Tensorと同じサイズ(1,28,28,1)にリサイズ、0~255を0.0~1.0に正規化、型をfloat32にキャスト、しています。
最後のキャストが非常に重要で、環境に依っては割り算した後に自動的に64-bitになってしまいます。それだと入力TensorのType(32-bit float)と合わないと怒られてしまいます。(エラー: ValueError: Cannot set tensor: Got tensor of type 0 but expected type 1 for input 10)

その後、tfliteモデルをロードして、入出力Tensor情報を取得しています。ここまでは初期処理として一気にやってしまいます。

用意が出来たら、入力Tensorに変換した画像を設定し、run(invoke)します。最後に、出力Tensorから結果を取り出して表示しています。

number_detector_tflite.py
# -*- coding: utf-8 -*-
import cv2
import tensorflow as tf
import numpy as np

if __name__ == '__main__':
    # prepara input image
    img = cv2.imread('resource/4.jpg')
    cv2.imshow('image', img)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img = cv2.resize(img, (28, 28))
    img = 255 - img
    img = img.reshape(1, img.shape[0], img.shape[1], 1) # (1, 28, 28, 1)
    img = img / 255.
    img = img.astype(np.float32)

    # load model
    # for tensorflow 1.12
    # interpreter = tf.contrib.lite.Interpreter(model_path="conv_mnist.tflite")
    # for tensorflow-nightly (1.14)
    interpreter = tf.lite.Interpreter(model_path="conv_mnist.tflite")
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # set input tensor
    interpreter.set_tensor(input_details[0]['index'], img)

    # run
    interpreter.invoke()

    # get outpu tensor
    probs = interpreter.get_tensor(output_details[0]['index'])

    # print result
    result = np.argmax(probs[0])
    score = probs[0][result]
    print("predicted number is {} [{:.2f}]".format(result, score))

    cv2.waitKey(0)
    cv2.destroyAllWindows()
実行結果
predicted number is 4 [0.93]

Google Colaboratoryで試す

本記事の内容を、Google Colaboratoryで試します。
使用するTensorflowのバージョンは、2.1.0-rc1でした。

https://github.com/iwatake2222/colaboratory_study/blob/master/DL_tutorial/DL_tutorial_04.ipynb

モデル変換 (Keras(H5) -> Tensorflow Lite Model(tflite))

Kerasモデルからtfliteモデルに変換してみます。

keras2tflite
%tensorflow_version 2.x
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf

print(tf.__version__)

# Download Keraas model
!wget -O conv_mnist.h5 "https://drive.google.com/uc?export=download&id=1OLR1n5Pq0CGPz7Bw5pad-fvYsgCnKvHh" 

# Convert
loaded_model = tf.keras.models.load_model("conv_mnist.h5")
converter = tf.lite.TFLiteConverter.from_keras_model(loaded_model)

tflite_model = converter.convert()
open("conv_mnist.tflite", "wb").write(tflite_model)

モデル変換 (Tensorflow Saved Model(pb) -> Tensorflow Lite Model(tflite))

Tensorflow saved modelからtfliteモデルに変換してみます。

tf2tflite
%tensorflow_version 2.x
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf

print(tf.__version__)

# Download Saved model
!wget -O conv_mnist_saved_model.tar.gz "https://drive.google.com/uc?export=download&id=1T0_2UYERZkTQnBZ4ocfnT7j1u0XJ2xKm"
!tar zxvf conv_mnist_saved_model.tar.gz
saved_model_dir = "./conv_mnist_saved_model"

# Convert
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)

tflite_model = converter.convert()
open("conv_mnist_from_pb.tflite", "wb").write(tflite_model)

TFLiteモデルによる推論テスト

tflite形式に変換したモデルを使って推論してみます。

TFLiteモデルによる推論テスト
import cv2
from google.colab.patches import cv2_imshow
import numpy as np

# Read input image
!wget -O 4.jpg  "https://drive.google.com/uc?export=download&id=1-3yb3qCrN8M6Bdj7ZZ9UMjONh34R2W_A" 
img = cv2.imread("4.jpg")
cv2_imshow(img)

# Pre process
## グレースケール化、リサイズ、白黒判定、価範囲を0~255 -> 0.0~1.0
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img = cv2.resize(img, (28, 28))
img = 255 - img
img = img / 255.
img = img.astype(np.float32)
input_tensor = img.reshape(1, img.shape[0], img.shape[1], 1)
input_tensor = tf.convert_to_tensor(input_tensor)

# Load model
interpreter = tf.lite.Interpreter(model_path="conv_mnist.tflite")
# interpreter = tf.lite.Interpreter(model_path="conv_mnist_from_pb.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# set input tensor
interpreter.set_tensor(input_details[0]['index'], input_tensor)

# Inference
interpreter.invoke()

scores = interpreter.get_tensor(output_details[0]['index'])
result = np.argmax(scores[0])
print("predicted number is {} [{:.2f}]".format(result, scores[0][result]))
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away