Native build of Tensorflow v2.0.0-beta for Raspberry Pi3 (armv7l)

Posted at 2019-06-10

Tensorflow-bin: https://github.com/PINTO0309/Tensorflow-bin

Bazel_bin: https://github.com/PINTO0309/Bazel_bin

1.Introduction

This is an unofficial log of my struggle to compile the beta version, so please do not expect too much. My article on compiling the previous alpha version is RaspberryPi3用のTensorflow v2.0.0-alpha (Tensorflow Lite v1.0) のインストーラ(Wheel)を速攻でネイティブビルド錬成した (a quick native build of the Tensorflow v2.0.0-alpha / Tensorflow Lite v1.0 installer Wheel for the RaspberryPi3). The only differences from the procedure at that time are the Bazel build steps and the version numbers. I spent a whole day getting the ARM build of Bazel to go through on the RaspberryPi3. The usage restrictions of the pre-compiled binaries are described below, so I recommend reading the Introduction chapter that follows carefully. That said, you can do the same thing just by copy-and-paste without reading the English text. It is monotonous work that anyone can do, but please take on the challenge only if you have the extraordinary stamina to keep compiling for two full days. As an aside, the previous alpha build happened to fall on my birthday, so it took a full three months for the official project to move from alpha to beta.

I created a Wheel package of Tensorflow v2.0.0-beta, published on June 8, 2019, for the RaspberryPi3. I have not confirmed the operation of every OP, but I hope you find it useful. By the way, the fully compiled Wheel file for the RaspberryPi3 is available in the GitHub repository linked above (Tensorflow-bin).

The Tensorflow Lite built with this procedure is about 2.5 times faster than the official binary when multithreading is enabled. Note, however, that not every model gets an equal 2.5x speedup. Only the convolution layers are subject to multi-thread parallel processing, so you need to keep in mind that the other layers are not accelerated. See the following link for a discussion of Tensorflow Lite acceleration between engineers around the world and me: Tensorflow Lite, python API does not work #21574. Also, the binary I generated excludes matrix_square_root_op from the compilation targets in order to avoid running out of memory on the RaspberryPi3 at compile time. matrix_square_root_op alone consumes more than 3GB in total: 2GB of swap area and 1GB of physical memory.
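Incidentally, as a rough illustration of what excluding matrix_square_root_op means in practice, the sketch below calls the op that depends on that kernel. This is only my assumption of the failure mode on a wheel built this way (the file name and exact exception type are illustrative, not verified output); the multi-thread usage itself is demonstrated in section 5.

matrix_square_root_check.py
import tensorflow as tf

x = tf.constant([[4.0, 0.0], [0.0, 9.0]])
try:
  # tf.linalg.sqrtm is backed by the MatrixSquareRoot kernel
  # (matrix_square_root_op), which is excluded from this build,
  # so this call is expected to fail with a missing-kernel error.
  print(tf.linalg.sqrtm(x))
except tf.errors.OpError as e:
  print("MatrixSquareRoot kernel is not available in this build:", e)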

For details on the benefits of Tensorflow v2.0.0-beta, the links below may be helpful.

Announcing TensorFlow 2.0 Beta
TensorFlow Core - Overview
TensorFlow Core - Tutorials
TensorFlow Core - Guide
TensorFlow Core - TensorFlow 2.0 Beta

The cross-compilation steps listed on the official site often do not succeed. Even when the build succeeds, you may notice that the generated binary has various problems. The reason I stick to insanely long native builds is that I do not want to generate broken binaries.

2.Environment

  • RaspberryPi3 model B+ (armhf/armv7l/Raspbian Stretch)
  • Tensorflow v2.0.0-beta0
  • Bazel 0.24.1

3.Procedure

3−1.Preparation before compilation

Prepare to build Tensorflow v2.0.0-beta.

Preparation_before_compilation
$ sudo nano /etc/dphys-swapfile
CONF_SWAPSIZE=2048
CONF_MAXSWAP=2048

$ sudo systemctl stop dphys-swapfile
$ sudo systemctl start dphys-swapfile

$ wget https://github.com/PINTO0309/Tensorflow-bin/raw/master/zram.sh
$ chmod 755 zram.sh
$ sudo mv zram.sh /etc/init.d/
$ sudo update-rc.d zram.sh defaults
$ sudo reboot

$ sudo apt-get install -y libhdf5-dev libc-ares-dev libeigen3-dev
$ sudo pip3 install keras_applications==1.0.7 --no-deps
$ sudo pip3 install keras_preprocessing==1.0.9 --no-deps
$ sudo pip3 install h5py==2.9.0
$ sudo apt-get install -y openmpi-bin libopenmpi-dev
$ sudo -H pip3 install -U --user six numpy wheel mock
$ sudo apt update;sudo apt upgrade

### For Stretch
$ sudo apt-get install -y openjdk-8-jdk
or
### For Buster
$ sudo apt-get install -y openjdk-11-jdk

$ cd ~
$ git clone https://github.com/PINTO0309/Bazel_bin.git
$ cd Bazel_bin

### For Stretch
$ ./0.24.1/Raspbian_Stretch_armhf/install.sh
or
### For Buster
$ ./0.24.1/Raspbian_Buster_armhf/install.sh

$ cd ~
$ git clone -b v2.0.0-beta0 https://github.com/tensorflow/tensorflow.git
$ cd tensorflow
$ git checkout -b v2.0.0-beta0

3−2.Edit various files

Add MultiThread function to Tensorflow Lite's Python API. This means customizing Tensorflow Lite's implementation on your own. [tflite] export SetNumThreads to TFLite Python API #25748

tensorflow/lite/python/interpreter.py
# Append the following two lines at the end of the file
  def set_num_threads(self, i):
    return self._interpreter.SetNumThreads(i)
tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
// Modify the code near the end of the file as follows
PyObject* InterpreterWrapper::ResetVariableTensors() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::SetNumThreads(int i) {
  interpreter_->SetNumThreads(i);
  Py_RETURN_NONE;
}

}  // namespace interpreter_wrapper
}  // namespace tflite
tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h
  // should be the interpreter object providing the memory.
  PyObject* tensor(PyObject* base_object, int i);

  PyObject* SetNumThreads(int i);

 private:
  // Helper function to construct an `InterpreterWrapper` object.
  // It only returns InterpreterWrapper if it can construct an `Interpreter`.

Disable compilation of matrix_square_root_op.

tensorflow/tensorflow/core/kernels/BUILD
cc_library(
    name = "linalg",
    deps = [
        ":cholesky_grad",
        ":cholesky_op",
        ":determinant_op",
        ":lu_op",
        ":matrix_exponential_op",
        ":matrix_inverse_op",
        ":matrix_logarithm_op",
        ":matrix_solve_ls_op",
        ":matrix_solve_op",
        ":matrix_triangular_solve_op",
        ":qr_op",
        ":self_adjoint_eig_op",
        ":self_adjoint_eig_v2_op",
        ":svd_op",
        ":tridiagonal_solve_op",
    ],
)
tensorflow/tensorflow/core/kernels/BUILD_(Delete_the_following)
tf_kernel_library(
    name = "matrix_square_root_op",
    prefix = "matrix_square_root_op",
    deps = LINALG_DEPS,
)

Disable NNAPI.

tensorflow/lite/tools/make/Makefile
BUILD_WITH_NNAPI=false

3−3.Configure Tensorflow v2.0.0-beta

Set build parameters of Tensorflow v2.0.0-beta.

Procedure_of_configure
$ cd ~/tensorflow
$ ./configure

Extracting Bazel installation...
WARNING: --batch mode is deprecated. Please instead explicitly shut down your Bazel server using the command "bazel shutdown".
You have bazel 0.24.1- (@non-git) installed.
Please specify the location of python. [Default is /usr/bin/python]: /usr/bin/python3


Found possible Python library paths:
  /usr/local/lib
  /usr/lib/python3/dist-packages
  /home/pi/inference_engine_vpu_arm/python/python3.5
  /usr/local/lib/python3.5/dist-packages
Please input the desired Python library path to use.  Default is [/usr/local/lib]
/usr/local/lib/python3.5/dist-packages
Do you wish to build TensorFlow with XLA JIT support? [Y/n]: n
No XLA JIT support will be enabled for TensorFlow.

Do you wish to build TensorFlow with OpenCL SYCL support? [y/N]: n
No OpenCL SYCL support will be enabled for TensorFlow.

Do you wish to build TensorFlow with ROCm support? [y/N]: n
No ROCm support will be enabled for TensorFlow.

Do you wish to build TensorFlow with CUDA support? [y/N]: n
No CUDA support will be enabled for TensorFlow.

Do you wish to download a fresh release of clang? (Experimental) [y/N]: n
Clang will not be downloaded.

Do you wish to build TensorFlow with MPI support? [y/N]: n
No MPI support will be enabled for TensorFlow.

Please specify optimization flags to use during compilation when bazel option "--config=opt" is specified [Default is -march=native -Wno-sign-compare]: 


Would you like to interactively configure ./WORKSPACE for Android builds? [y/N]: n
Not configuring the WORKSPACE for Android builds.

Preconfigured Bazel build configs. You can use any of the below by adding "--config=<>" to your build command. See .bazelrc for more details.
    --config=mkl            # Build with MKL support.
    --config=monolithic     # Config for mostly static monolithic build.
    --config=gdr            # Build with GDR support.
    --config=verbs          # Build with libverbs support.
    --config=ngraph         # Build with Intel nGraph support.
    --config=numa           # Build with NUMA support.
    --config=dynamic_kernels    # (Experimental) Build kernels into separate shared objects.
Preconfigured Bazel build configs to DISABLE default on features:
    --config=noaws          # Disable AWS S3 filesystem support.
    --config=nogcp          # Disable GCP support.
    --config=nohdfs         # Disable HDFS support.
    --config=noignite       # Disable Apache Ignite support.
    --config=nokafka        # Disable Apache Kafka support.
    --config=nonccl         # Disable NVIDIA NCCL support.
Configuration finished

3−4.Build Tensorflow v2.0.0-beta

Build_command_by_Bazel_0.24.1
$ sudo bazel --host_jvm_args=-Xmx512m build \
--config=opt \
--config=noaws \
--config=nogcp \
--config=nohdfs \
--config=noignite \
--config=nokafka \
--config=nonccl \
--local_resources=1024.0,0.5,0.5 \
--copt=-mfpu=neon-vfpv4 \
--copt=-ftree-vectorize \
--copt=-funsafe-math-optimizations \
--copt=-ftree-loop-vectorize \
--copt=-fomit-frame-pointer \
--copt=-DRASPBERRY_PI \
--host_copt=-DRASPBERRY_PI \
//tensorflow/tools/pip_package:build_pip_package

3−5.Create Wheel file

Command_to_create_Wheel_file
$ su --preserve-environment
# ./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
# exit
$ sudo cp /tmp/tensorflow_pkg/tensorflow-2.0.0b0-cp35-cp35m-linux_armv7l.whl ~

3−6.Installation of Tensorflow v2.0.0-beta

Installation_command
$ cd ~
$ sudo pip3 uninstall tensorflow
$ sudo -H pip3 install tensorflow-2.0.0b0-cp35-cp35m-linux_armv7l.whl 

4.Operation check

Command_to_check_the_installed_Tensorflow_version
$ python3 -c 'import tensorflow as tf; print(tf.__version__)'
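If you also want to confirm that the multi-thread patch from section 3-2 made it into the installed wheel, a minimal check along the following lines should work (the file name and model path are just examples; any valid .tflite model, such as the one downloaded in section 5, can be used):

check_set_num_threads.py
from tensorflow.lite.python import interpreter as interpreter_wrapper

# The model path is a placeholder; point it at any valid .tflite model.
interpreter = interpreter_wrapper.Interpreter(
    model_path="mobilenet_v1_1.0_224_quant.tflite")
interpreter.allocate_tensors()

# set_num_threads exists only if the patch from section 3-2 was applied
# before building the wheel.
print(hasattr(interpreter, "set_num_threads"))  # expected: True
interpreter.set_num_threads(4)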

5.Performance test of MultiThread with MobileNetV2

$ cd ~;mkdir test
$ curl https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/lite/examples/label_image/testdata/grace_hopper.bmp > ~/test/grace_hopper.bmp
$ curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz | tar xzv -C ~/test mobilenet_v1_1.0_224/labels.txt
$ mv ~/test/mobilenet_v1_1.0_224/labels.txt ~/test/
$ curl http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224_quant.tgz | tar xzv -C ~/test
$ cp tensorflow/tensorflow/lite/examples/python/label_image.py ~/test

Edit label_image.py.

[Sample Code] label_image.py
import argparse
import numpy as np
import time

from PIL import Image

# Tensorflow v2.x.x
from tensorflow.lite.python import interpreter as interpreter_wrapper

def load_labels(filename):
  my_labels = []
  input_file = open(filename, 'r')
  for l in input_file:
    my_labels.append(l.strip())
  return my_labels
if __name__ == "__main__":
  floating_model = False
  parser = argparse.ArgumentParser()
  parser.add_argument("-i", "--image", default="/tmp/grace_hopper.bmp", \
    help="image to be classified")
  parser.add_argument("-m", "--model_file", \
    default="/tmp/mobilenet_v1_1.0_224_quant.tflite", \
    help=".tflite model to be executed")
  parser.add_argument("-l", "--label_file", default="/tmp/labels.txt", \
    help="name of file containing labels")
  parser.add_argument("--input_mean", default=127.5, help="input_mean")
  parser.add_argument("--input_std", default=127.5, \
    help="input standard deviation")
  parser.add_argument("--num_threads", default=1, help="number of threads")
  args = parser.parse_args()

  interpreter = interpreter_wrapper.Interpreter(model_path=args.model_file)
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()
  output_details = interpreter.get_output_details()
  # check the type of the input tensor
  if input_details[0]['dtype'] == np.float32:
    floating_model = True
  # NxHxWxC, H:1, W:2
  height = input_details[0]['shape'][1]
  width = input_details[0]['shape'][2]
  img = Image.open(args.image)
  img = img.resize((width, height))
  # add N dim
  input_data = np.expand_dims(img, axis=0)
  if floating_model:
    input_data = (np.float32(input_data) - args.input_mean) / args.input_std

  interpreter.set_num_threads(int(args.num_threads)) #<--- You need to add this line.
  interpreter.set_tensor(input_details[0]['index'], input_data)

  start_time = time.time()
  interpreter.invoke()
  stop_time = time.time()

  output_data = interpreter.get_tensor(output_details[0]['index'])
  results = np.squeeze(output_data)
  top_k = results.argsort()[-5:][::-1]
  labels = load_labels(args.label_file)
  for i in top_k:
    if floating_model:
      print('{0:08.6f}'.format(float(results[i]))+":", labels[i])
    else:
      print('{0:08.6f}'.format(float(results[i]/255.0))+":", labels[i])

  print("time: ", stop_time - start_time)

Run test.

1_thread_execution_sample
$ cd ~/test
$ python3 label_image.py \
--num_threads 1 \
--image grace_hopper.bmp \
--model_file mobilenet_v1_1.0_224_quant.tflite \
--label_file labels.txt

0.415686: 653:military uniform
0.352941: 907:Windsor tie
0.058824: 668:mortarboard
0.035294: 458:bow tie, bow-tie, bowtie
0.035294: 835:suit, suit of clothes
time:  0.4152982234954834
4_thread_execution_sample
$ cd ~/test
$ python3 label_image.py \
--num_threads 4 \
--image grace_hopper.bmp \
--model_file mobilenet_v1_1.0_224_quant.tflite \
--label_file labels.txt

0.415686: 653:military uniform
0.352941: 907:Windsor tie
0.058824: 668:mortarboard
0.035294: 458:bow tie, bow-tie, bowtie
0.035294: 835:suit, suit of clothes
time:  0.1647195816040039

6.Reference articles

Compiling Bazel from source - Build Bazel from scratch (bootstrapping) - 2. Bootstrap Bazel on Ubuntu Linux, macOS, and other Unix-like systems

Converter Python API guide
