LoginSignup
17
27

More than 3 years have passed since last update.

Deep Learningアプリケーション開発 (3) TensorFlow with C

Last updated at Posted at 2019-03-08

この記事について

機械学習、Deep Learningの専門家ではない人が、Deep Learningを応用したアプリケーションを作れるようになるのが目的です。MNIST数字識別する簡単なアプリケーションを、色々な方法で作ってみます。特に、組み込み向けアプリケーション(Edge AI)を意識しています。
モデルそのものには言及しません。数学的な話も出てきません。Deep Learningモデルをどうやって使うか(エッジ推論)、ということに重点を置いています。

  1. Kerasで簡単にMNIST数字識別モデルを作り、Pythonで確認
  2. TensorFlowモデルに変換してPythonで使用してみる (Windows, Linux)
  3. TensorFlowモデルに変換してCで使用してみる (Windows, Linux) <--- 今回の内容
  4. TensorFlow Liteモデルに変換してPythonで使用してみる (Windows, Linux)
  5. TensorFlow Liteモデルに変換してCで使用してみる (Linux)
  6. TensorFlow Liteモデルに変換してC++で使用してみる (Raspberry Pi)
  7. TensorFlow LiteモデルをEdge TPU上で動かしてみる (Raspberry Pi)

今回の内容

  • TensorFlow for C ライブラリの用意をする
  • プロジェクトを用意する
  • TensorFlow用モデルを使って、入力画像から数字識別するCアプリケーションを作る
    • モデルは前回作成済みのものを使用

ソースコードとサンプル入力画像: https://github.com/take-iwiw/CNN_NumberDetector/tree/master/03_Tensorflow_C

環境

  • OS: Windows 10 (64-bt)
  • OS(on VirtualBox): Ubuntu 16.04
  • CPU = Intel Core i7-6700@3.4GHz (物理コア=4、論理プロセッサ数=8)
  • GPU = NVIDIA GeForce GTX 1070 (← GPUは無くても大丈夫です)
  • 開発環境(Windows): Visual Studio Community 2017, cmake-gui
  • 開発環境(Linux): cmake, gcc
  • TensorFlow 1.12.0
  • パッケージ詳細はこちら Windows用Linux用

今回の内容は、WindowsとLinux(Ubuntu)のどちらでも動きますが、本記事の説明はWindowsメインで行います。

TensorFlow for C Libraryの用意をする

https://www.tensorflow.org/install/lang_c が情報がまとまっているページです。
ライブラリの用意方法は2つあります。

  1. ビルド済みのライブラリをダウンロードする
  2. 自分でビルドする

ラズパイ等の別の環境でも試したい方、カスタマイズしたい方は自分でビルドする必要があります。
カスタマイズ例としては、拡張命令の使用指定があります。ビルド済みのライブラリを使用すると、環境に依っては I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 といったワーニングが出ます。せっかく、CPUが拡張命令をサポートしているのに、TensorFlowがそれを使うようにコンパイルされてないよ、と言っています。実はこの警告はPython版で実行しても出ます。恐らく、ビルド済み版やパッケージ版では汎用性を重視しているのかなと思います。

今回は、1. ビルド済みのライブラリをダウンロードして使います。
2019年3月8現在、最新の以下2つのファイルをダウンロードします。

これで、Linux用の共有ライブラリ(libtensorflow.so)とWindows用の共有ライブラリ(tensorflow.dll)が手に入りました。Windowsの方はこのままだとビルドに使用できないのでtensorflow.lib を作成する必要があります。手順は、https://qiita.com/take-iwiw/items/3caf294726918c3af9e6 をご参照ください。

プロジェクトを用意する

WindowsとLinuxのどちらにも対応できる、クロスプラットフォームなプロジェクトにしたいと思います。そのためにCMakeを使います。

以下のようなディレクトリ構成を用意します。(https://github.com/take-iwiw/CNN_NumberDetector/tree/master/03_Tensorflow_C もご参考にしてください)

プロジェクト構成
├─libtensorflow/
│  ├─libtensorflow-cpu-linux-x86_64-1.12.0/   <- ダウンロードしたものを展開
│  │  ├─include/tensorflow/c/c_api.h
│  │  └─lib/libtensorflow.so, libtensorflow_framework.so
│  └─libtensorflow-cpu-windows-x86_64-1.12.0/  <- ダウンロードしたものを展開
│      ├─include/tensorflow/c/c_api.h
│      └─lib/tensorflow.dll, tensorflow.lib
├─resource/
│  ├─0.jpg            <- サンプル画像
│  └─conv_mnist.pb    <- 前回作成した、数字識別用モデル (Keras用モデルからTensorFlow用モデルに変換済み)
├─CMakeLists.txt
├─main.cpp
└─tf_utils.cpp,h

必要なファイル

  • libtensorflow
    • ダウンロードしたTensorFlow for Cライブラリと、作成したtensorflow.libを配置します。
  • resource
    • サンプル画像とモデルを配置します。実行ディレクトリにコピーします。実行時に使われます。
  • CMakeLists.txt
    • CMakeプロジェクト設定を記載します (後述)
  • main.cpp
    • TensorFlow for Cを使って数字識別するアプリケーションのソースコードです (後述)
  • tf_utils.cpp, tf_utils.h
    • TensorFlow for Cを使いやすくしてくれるユーティリティクラスです (MIT License)
    • https://github.com/Neargye/hello_tf_c_api
    • 本ソースコードもこのページを大変参考にさせていただきました

CMakeLists.txt

以下のようなCMakeLists.txtを作ります。

あらかじめ環境変数にOpenCV_DIRを設定しておいてください。
WindowsかLinuxかの環境に合わせて適切なライブラリをリンクします。また、実行バイナリファイルと同じ場所にコピーします。
必要なリソースファイル(サンプル画像とモデルファイル)も実行バイナリファイルと同じ場所にコピーします。

CMakeLists.txt
cmake_minimum_required(VERSION 2.8)
project(NumberDetector)

# Create Main project
add_executable(NumberDetector
    main.cpp
    tf_utils.cpp
    tf_utils.hpp
)

# For OpenCV
find_package(OpenCV REQUIRED)
if(OpenCV_FOUND)
    target_include_directories(NumberDetector PUBLIC ${OpenCV_INCLUDE_DIRS})
    target_link_libraries(NumberDetector ${OpenCV_LIBS})
endif()

# For Tensorflow
if(WIN32)
    target_link_libraries(NumberDetector ${PROJECT_SOURCE_DIR}/libtensorflow/libtensorflow-cpu-windows-x86_64-1.12.0/lib/tensorflow.lib)
    target_include_directories(NumberDetector PUBLIC ${PROJECT_SOURCE_DIR}/libtensorflow/libtensorflow-cpu-windows-x86_64-1.12.0/include)
    file(COPY ${PROJECT_SOURCE_DIR}/libtensorflow/libtensorflow-cpu-windows-x86_64-1.12.0/lib/tensorflow.dll DESTINATION ${PROJECT_BINARY_DIR})
else()
    target_link_libraries(NumberDetector ${PROJECT_SOURCE_DIR}/libtensorflow/libtensorflow-cpu-linux-x86_64-1.12.0/lib/libtensorflow.so)
    target_include_directories(NumberDetector PUBLIC ${PROJECT_SOURCE_DIR}/libtensorflow/libtensorflow-cpu-linux-x86_64-1.12.0/include)
    file(COPY ${PROJECT_SOURCE_DIR}/libtensorflow/libtensorflow-cpu-linux-x86_64-1.12.0/lib/libtensorflow.so DESTINATION ${PROJECT_BINARY_DIR})
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -lstdc++")
endif()

# Copy resouce
file(COPY ${CMAKE_SOURCE_DIR}/resource/ DESTINATION ${PROJECT_BINARY_DIR}/resource/)
add_definitions(-DRESOURCE_DIR="${PROJECT_BINARY_DIR}/resource/")

TensorFlow用モデルを使って、入力画像から数字識別するCアプリケーションを作る

現時点では、残念ながらTensorFlow for CのAPIリファレンスマニュアルがありません。。。
C++用は https://www.tensorflow.org/versions/r1.12/api_docs/cc にあるのですが、C用はありません。(そのわりにC++用ビルド済みライブラリの配布はしていない。。。)

先ほどの、https://github.com/Neargye/hello_tf_c_api のページを参考に実装してみます。また、TensorFlow for CのAPIそのままだとモデルのロードが面倒なのですが、tf_utils.cppで良い感じにwrapしてくれています。

main.cpp
#include <stdio.h>
#include <opencv2/opencv.hpp>
#include <tensorflow/c/c_api.h>
#include "tf_utils.hpp"

#define MODEL_FILENAME RESOURCE_DIR"conv_mnist.pb"

static int displayGraphInfo()
{
    TF_Graph *graph = tf_utils::LoadGraphDef(MODEL_FILENAME);
    if (graph == nullptr) {
        std::cout << "Can't load graph" << std::endl;
        return 1;
    }

    size_t pos = 0;
    TF_Operation* oper;
    printf("--- graph info ---\n");
    while ((oper = TF_GraphNextOperation(graph, &pos)) != nullptr) {
        printf("%s\n", TF_OperationName(oper));
    }
    printf("--- graph info ---\n");

    TF_DeleteGraph(graph);
    return 0;
}

int main()
{
    printf("Hello from TensorFlow C library version %s\n", TF_Version());

    /* read input image data */
    cv::Mat image = cv::imread(RESOURCE_DIR"4.jpg");
    cv::imshow("InputImage", image);

    /* convert to 28 x 28 grayscale image (normalized: 0 ~ 1.0) */
    cv::cvtColor(image, image, CV_BGR2GRAY);
    cv::resize(image, image, cv::Size(28, 28));
    image = ~image;
    cv::imshow("InputImage for CNN", image);
    image.convertTo(image, CV_32FC1, 1.0 / 255);

    /* get graph info */
    displayGraphInfo();

    TF_Graph *graph = tf_utils::LoadGraphDef(MODEL_FILENAME);
    if (graph == nullptr) {
        std::cout << "Can't load graph" << std::endl;
        return 1;
    }

    /* prepare input tensor */
    TF_Output input_op = { TF_GraphOperationByName(graph, "input_1"), 0 };
    if (input_op.oper == nullptr) {
        std::cout << "Can't init input_op" << std::endl;
        return 2;
    }

    const std::vector<std::int64_t> input_dims = { 1, 28, 28, 1 };
    std::vector<float> input_vals;
    image.reshape(0, 1).copyTo(input_vals); // Mat to vector

    TF_Tensor* input_tensor = tf_utils::CreateTensor(TF_FLOAT,
        input_dims.data(), input_dims.size(),
        input_vals.data(), input_vals.size() * sizeof(float));

    /* prepare output tensor */
    TF_Output out_op = { TF_GraphOperationByName(graph, "dense_1/Softmax"), 0 };
    if (out_op.oper == nullptr) {
        std::cout << "Can't init out_op" << std::endl;
        return 3;
    }

    TF_Tensor* output_tensor = nullptr;

    /* prepare session */
    TF_Status* status = TF_NewStatus();
    TF_SessionOptions* options = TF_NewSessionOptions();
    TF_Session* sess = TF_NewSession(graph, options, status);
    TF_DeleteSessionOptions(options);

    if (TF_GetCode(status) != TF_OK) {
        TF_DeleteStatus(status);
        return 4;
    }

    /* run session */
    TF_SessionRun(sess,
        nullptr, // Run options.
        &input_op, &input_tensor, 1, // Input tensors, input tensor values, number of inputs.
        &out_op, &output_tensor, 1, // Output tensors, output tensor values, number of outputs.
        nullptr, 0, // Target operations, number of targets.
        nullptr, // Run metadata.
        status // Output status.
    );

    if (TF_GetCode(status) != TF_OK) {
        std::cout << "Error run session";
        TF_DeleteStatus(status);
        return 5;
    }

    TF_CloseSession(sess, status);
    if (TF_GetCode(status) != TF_OK) {
        std::cout << "Error close session";
        TF_DeleteStatus(status);
        return 6;
    }

    TF_DeleteSession(sess, status);
    if (TF_GetCode(status) != TF_OK) {
        std::cout << "Error delete session";
        TF_DeleteStatus(status);
        return 7;
    }

    const auto probs = static_cast<float*>(TF_TensorData(output_tensor));

    for (int i = 0; i < 10; i++) {
        printf("prob of %d: %.3f\n", i, probs[i]);
    }

    TF_DeleteTensor(input_tensor);
    TF_DeleteTensor(output_tensor);
    TF_DeleteGraph(graph);
    TF_DeleteStatus(status);

    cv::waitKey(0);
    return 0;
}

displayGraphInfo() という関数で、モデル情報を表示しています。これは、実際のアプリケーションでは不要です。ですが、入出力Tensorの名前を取得するのに必要です。前回、Keras用モデルからTensorFlow用モデルに変換する際にもPythonスクリプトで同様のことをやりましたが、どうも名前が微妙に異なるようです(先頭のimport/と最後の:0が消えるだけかな)。

プログラムでは、まずOpenCVを使って画像を読み込み、28x28にリサイズ、白黒反転し、0~255を0.0~1.0に変換しています。
その後、tf_utils::LoadGraphDef()でモデルを読み込んでいます。
事前にdisplayGraphInfo()で調べた入出力Tensor名(input_1, dense_1/Softmax <- この名前は環境によって変わる可能性有り)を使って、入出力Tensorを取得しています。
入力Tensorには、入力データが必要なので、OpenCV::Matを1次元にreshapeして、vector型にコピーしています。このvectorのサイズは、モデルの入力Tensorのサイズと同じである必要があります。今回の場合だと、(1, 28, 28, 1)です。
そして、TensorFlow sessionの用意をします。ここは詳しくは見ていませんが、お決まりの処理なんだと思います。ここではTensorFlowの関数を直接呼んでいますが、tfutils.cppではこれらの処理をまとめた関数を用意しています。
セッションのrunが完了して、エラーが無かったら出力Tensorから出力値を取得しています。

実行結果
Hello from TensorFlow C library version 1.12.0
--- graph info ---
input_1
conv2d_1/kernel
conv2d_1/bias
conv2d_1/Conv2D/ReadVariableOp
conv2d_1/Conv2D
conv2d_1/BiasAdd/ReadVariableOp
conv2d_1/BiasAdd
conv2d_1/Relu
max_pooling2d_1/MaxPool
conv2d_2/kernel
conv2d_2/bias
conv2d_2/Conv2D/ReadVariableOp
conv2d_2/Conv2D
conv2d_2/BiasAdd/ReadVariableOp
conv2d_2/BiasAdd
conv2d_2/Relu
dropout_1/keras_learning_phase/input
dropout_1/keras_learning_phase
dropout_1/cond/Switch
dropout_1/cond/switch_t
dropout_1/cond/pred_id
dropout_1/cond/dropout/keep_prob
dropout_1/cond/dropout/Shape/Switch
dropout_1/cond/dropout/Shape
dropout_1/cond/dropout/random_uniform/min
dropout_1/cond/dropout/random_uniform/max
dropout_1/cond/dropout/random_uniform/RandomUniform
dropout_1/cond/dropout/random_uniform/sub
dropout_1/cond/dropout/random_uniform/mul
dropout_1/cond/dropout/random_uniform
dropout_1/cond/dropout/add
dropout_1/cond/dropout/Floor
dropout_1/cond/dropout/div
dropout_1/cond/dropout/mul
dropout_1/cond/Identity/Switch
dropout_1/cond/Identity
dropout_1/cond/Merge
flatten_1/Shape
flatten_1/strided_slice/stack
flatten_1/strided_slice/stack_1
flatten_1/strided_slice/stack_2
flatten_1/strided_slice
flatten_1/Reshape/shape/1
flatten_1/Reshape/shape
flatten_1/Reshape
dense_1/kernel
dense_1/bias
dense_1/MatMul/ReadVariableOp
dense_1/MatMul
dense_1/BiasAdd/ReadVariableOp
dense_1/BiasAdd
dense_1/Softmax
Adam/iterations
Adam/lr
Adam/beta_1
Adam/beta_2
Adam/decay
training/Adam/Variable
training/Adam/Variable_1
training/Adam/Variable_2
training/Adam/Variable_3
training/Adam/Variable_4
training/Adam/Variable_5
training/Adam/Variable_6
training/Adam/Variable_7
training/Adam/Variable_8
training/Adam/Variable_9
training/Adam/Variable_10
training/Adam/Variable_11
training/Adam/Variable_12
training/Adam/Variable_13
training/Adam/Variable_14
training/Adam/Variable_15
training/Adam/Variable_16
training/Adam/Variable_17
--- graph info ---
2019-03-09 00:17:09.353952: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
prob of 0: 0.000
prob of 1: 0.001
prob of 2: 0.013
prob of 3: 0.043
prob of 4: 0.929
prob of 5: 0.003
prob of 6: 0.001
prob of 7: 0.006
prob of 8: 0.001
prob of 9: 0.003

はまったポイント

  • Linuxでビルドするときには、-std=c++11 -lstdc++オプションが必要だった
  • Linuxでビルド、実行するとき、シェル起動時にAnacondaをデフォルトで有効にしてあると、実行時に以下のようなエラーが出た。~/.bashrcからconda initに関する記載をコメントアウトする必要があった。WindowsでもAnacondaをパスに追加したら問題あるかもしれない。
エラー
(base) ~/Desktop/win_share/03_Tensorflow_C/build$ ./NumberDetector
[libprotobuf FATAL google/protobuf/stubs/common.cc:68] This program requires version 3.6.0 of the Protocol Buffer runtime library, but the installed version is 3.5.1.  Please update your library.  If you compiled the program yourself, make sure that your headers are from the same version of Protocol Buffers as your link-time library.  (Version verification failed in "external/protobuf_archive/src/google/protobuf/any.pb.cc".)
terminate called after throwing an instance of 'google::protobuf::FatalException'
  what():  This program requires version 3.6.0 of the Protocol Buffer runtime library, but the installed version is 3.5.1.  Please update your library.  If you compiled the program yourself, make sure that your headers are from the same version of Protocol Buffers as your link-time library.  (Version verification failed in "external/protobuf_archive/src/google/protobuf/any.pb.cc".)
Aborted (core dumped)
17
27
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
17
27