25
34

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

C++ で TensorFlow の推論処理を実行する

Posted at

C++ で TensorFlow の推論処理を実行する

はじめに

前回 は Python で作成したモデルに対して C++ で学習を行うところまで説明しました。
今回は,学習したモデルの freeze と推論処理の実行を行いたいと思います。

今回作成したコードは前回同様 github に置いてあるので、詳細はこちらをご確認ください。

実行環境

(前回と同じです)

  • Windows
  • Python 3.6
  • keras 2.2.4
  • tensorflow 1.10.0
  • Visual Studio 2015

なお、Python と C++ で使用する TensorFlow のバージョンは揃えていないとエラーが発生する場合があるようです。

C++ 用の tensorflow.dll はこちらのサイトからダウンロードしました。

基本的な流れ

  1. モデルの freeze を行う (Python)
  2. freeze されたモデルを使用して推論を行う (C++)

モデルの freeze を行う

前回作成した model.meta と checkpoint を読み込み、 tf.graph_util.convert_variables_to_constants() を使用してモデルの freeze を行います。

output_node_names = ['output/Softmax']

with tf.Session() as sess:
    # Restore the graph
    saver = tf.train.import_meta_graph("model.meta")

    # Load weights
    saver.restore(sess, './checkpoints/model.ckpt')

    # Freeze the graph
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names
        )

    # Save the frozen graph
    with open('frozen_graph.pb', 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())

C++ で推論を行う

Freeze されたモデルを読み込みます。

const string graph_def_filename = "frozen_graph.pb";

// Setup global state for TensorFlow.
tensorflow::port::InitMain(argv[0], &argc, &argv);

// Load a frozen model
tensorflow::GraphDef graph_def;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
                                        graph_def_filename, &graph_def));

読み込んだモデルを指定して Session を作成します。

// Create a session
std::unique_ptr<tensorflow::Session> session(tensorflow::NewSession(tensorflow::SessionOptions()));
TF_CHECK_OK(session->Create(graph_def));

評価用の画像データとラベルデータを読み込み、推論を実行します。

auto test_x = read_training_file("MNIST_data/t10k-images.idx3-ubyte");
auto test_y = read_label_file("MNIST_data/t10k-labels.idx1-ubyte");
predict(session, test_x, test_y);

推論の処理は以下の通りです。

"input" に (画像ファイル数, 784) の Tensor を指定し、 "output/Softmax" の出力を取得しています。
モデルで "Dropout" を使用している場合は、 keras_learning_phasefalse を指定する必要があるようです。

void predict(const std::unique_ptr<tensorflow::Session>& session, const vector<vector<float>>& batch, const vector<float>& labels) {
  // Create an input data
  tensorflow::Tensor lp(tensorflow::DT_BOOL, tensorflow::TensorShape({}));
  lp.flat<bool>().setZero();
  vector<std::pair<string, tensorflow::Tensor>> inputs = {
    {"input", MakeTensor(batch)},
    {"batch_normalization_1/keras_learning_phase", lp}
  };

  std::vector<tensorflow::Tensor> out_tensors;

  // Predict
  TF_CHECK_OK(session->Run(inputs, {"output/Softmax"}, {}, &out_tensors));
}

"output/Softmax" からは (入力画像ファイル数, 10) の Tensor が取得できます。
得られた Tensor から精度を計算する処理は以下の通りです。

int hits = 0;
for (auto tensor : out_tensors) {
  auto items = tensor.shaped<float, 2>({static_cast<int>(batch.size()), 10});
  for (int i = 0; i < batch.size(); i++) {
    int arg_max = 0;
    float val_max = items(i, 0);
    for (int j = 0; j < 10; j++) {
      if (items(i, j) > val_max) {
        arg_max = j;
        val_max = items(i, j);
      }
    }
    if (arg_max == labels[i]) {
      hits++;
    }
  }
}
std::cout << "Accuracy: " << hits / (float)batch.size() << std::endl;
25
34
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
25
34

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?