LoginSignup
6
1

More than 3 years have passed since last update.

Window10で TensorFlow liteを使ってみる - 後編

Last updated at Posted at 2020-10-03

1.前編のふりかえり

◆◆Window10で TensorFlow liteを使ってみる - 前編
前編では、MinGW64を用いて Windows下で利用できる TensorFlow liteライブラリを作成した。併せて LeNet MINSTの学習済みモデルを作成し、TensorFlow liteの簡単な動作確認を行った。

本編では、関数型言語ElixirとTensorFlow liteを組み合わせ、簡単な手書き数字認識アプリ PlugMnistを作ってみる。Exlixirと TensorFlow liteのドッキングには、Elixir/Erlangの他言語インターフェイス Portsを利用する。また、TensorFlow liteに入力する画像の加工には CImgライブラリを用いる。

2.Elixir/Erlangの他言語インターフェイス Ports

Elixir/Erlangは、C言語などを利用して機能拡張ができるように他言語インターフェイスを備えている。大きく Ports と NIFs と言う二つのインターフェイス手段がある[*1]。今回はそれらのうちの Ports を使って MNIST/TensorFlow lite を Elixirから利用してみよう。

[*1]Elixir/Erlangには、シームレスに分散環境に拡張できる強力なプロセス間通信が組み込まれている。その通信ノードを他言語で記述するという手段もあるが、まだ試しことがない。

それでは、iexを起動して Portsの動きを体感しおこう。
少し見づらいが、下の例では "cat"コマンドを Portとして Elixirに接続し、文字列"hello world!"(12文字)の送受信を行っている。このように、Portsを用いた機能拡張では、その実行部をOSの実行可能ファイルとして実装する。そして、Elixirとの通信は stdin/stdoutを介して行うことになる。

$ iex
Interactive Elixir (1.10.4) - press Ctrl+C to exit (type h() ENTER for help)
iex(1)> port = Port.open({:spawn, "cat"}, [:binary])
#Port<0.4>
iex(2)> Port.command(port, "hello world!")
true
iex(3)> flush
{#Port<0.4>, {:data, "hello world!"}}
:ok
iex(4)> Port.info(port)
[
  name: 'cat',
  links: [#PID<0.102.0>],
  id: 32,
  connected: #PID<0.102.0>,
  input: 12,
  output: 12,
  os_pid: 10852
]
iex(5)> Port.close(port)
true

この例では、(1)コマンド起動モードを :spawn、(2)データ形式を :binary、(3)通信プロトコルを :stream(デフォルト) に設定して Portを開いている(Port.open)。開いた Portの IDは #Port<0.4>だ。起動モード :spawnはシェルの助けを借りて PATHリスト上でコマンドを検索し起動するモード、プロトコル :streamは無手順でデータを垂れ流しするプロトコルだ。詳しくは、参考文献[3]を参照してほしい。

またここでは、Port.command()でデータを送信しているが、データ送受信の実体は"メッセージ"なので、よりプリミティブな send()で送信しても良い(次の例)。事実、受信データについては iexに溜まっていたメッセージを flushコマンドで吐き出してみせた。十中八九、本番の設計では、受信データは receiveガードで受け取ることになるだろう。

ふむ、参考文献[3]によると、Port.open()にはいろいろなオプションを指定することができる様だ。とはいえ、すべてのオプションを試している暇はないので、通信プロトコルの3つ(:stream, :packet, :line)の違いについてだけ調べておこう。
下の例では :streamに代わって :packetを指定した。一見すると :streamの場合と何も違わないのだが、舞台裏の通信路に流れる個々のデーターには、その先頭にデータ長が付加されている。オプション"{:packet, n}"の n には、データ長を格納するエリアのバイト数{1,2,4}を指定する。数値フォーマットは、通信系ではお決まりの BIG endianだ。

iex(1)> port = Port.open({:spawn_executable, "c:/msys64/usr/bin/cat"}, [:binary, {:packet, 2}])
#Port<0.4>
iex(2)> send(port, {self(), {:command, "hello mars!"}})
{#PID<0.102.0>, {:command, "hello mars!"}}
iex(3)> flush
{#Port<0.4>, {:data, "hello mars!"}}
:ok

次に示すプロトコル :lineは、改行文字をデータの区切りとして扱う行指向のプロトコルだ。オプション"{:line, N}"の N には、データを一時的に格納する通信バッファの大きさを指定する。改行文字を受け取るか、あるいはバッファが一杯になる度にメッセージが生成される。データ・タプルに付加されている :noeolは引き続きデータが届くことを、:eolはデータの区切りまで受け取ったことを表している。指定するバッファの大きさにもよるが、少々取り扱いが面倒に思える。

iex(1)> port = Port.open({:spawn, "cat"}, [:binary, {:line, 5}])
#Port<0.4>
iex(2)> Port.command(port, "hello jupiter!")
true
iex(3)> flush
{#Port<0.4>, {:data, {:noeol, "hello"}}}
{#Port<0.4>, {:data, {:noeol, " jupi"}}}
:ok
iex(4)> Port.command(port, "\n")
true
iex(5)> flush
{#Port<0.4>, {:data, {:eol, "ter!"}}}
:ok

以上、3つの通信プロトコルを見てきたが、もっとも扱い易いプロトコルは :packetだと思う。大概の場合は :packetを用いて設計を行うことになるだろう。

3.手書き数字認識アプリ Plug MNISTの設計

それでは「Elixir × TensorFlow lite」アプリ PlugMnistの設計に取り掛かろう。
構想設計はこんな感じかな、

ブラウザ上で Canvasコントロールに手書きした数字を JPEG画像として Elixir(HTTPサーバー)に送る。Elixirは、受け取った画像を所定のファイルに保存した後、Portを介して TensorFlow liteにこの画像を推論するよう依頼する。TensorFlow liteからは推論結果が返ってくるので、それを加工してブラウザ上に表示する。

plug_mnist.jpg
そうだなぁ、コードを起こす前にもう少し設計を分解しておこう。

  • DBは不要なので、HTTPサーバーは Phoenixではなく Plugで実装する
  • 画像ファイル"img.jpg"や学習済みモデル"mnist.tfl"は privディレクトリに保存する
  • TensorFlow lite推論器のファイル名は "tfl_interp.exe"とする
  • Port.open()する際に、コマンドライン引数で使用する学習済みモデルを指定する
  • 推論器"tfl_interp.exe"が受け付けるコマンドは下の2つ
    • "predict <jpeg image>" - 指定された画像を推論する
    • "info" - 内部情報の問い合わせ  
  • 推論結果は、分類クラスをkey、その確からしさをvalueとした JSON objectで返す

こんなところでいいかな(^^)
続く章では、心臓部の TflInterp(Elixir)と tfl_interp(C++)のコードを掲載する。
なお、これらのコードを含め PlugMnistの全コードを、https://github.com/shoz-f/plug_mnist にアップしている。

4.Elixirサイド - TflInterp Port

下のコードは、ElixirからTensorFlow liteへの架け橋に当たる TflInterpモジュールのコードだ。
オープンした Portの IDを何処かに覚えておく必要があるので、GenServerで実装している。

Portのオープンは、GenServerの初期化関数 init()で行っている。返り値のPort IDは、GenServerのステートとして保持する。使用する学習済みモデルは init()の引数 - すなわち Application.start()が定義している childrenリストに指定することにした。

TflInterpから他のモジュールに exportするサービスは predict()だけだ。その実体は handle_call({:predict,..)で、Portを介して推論器"tfl_interp.exe"にコマンドを送信し、その返信を receiveガードで待っている。返信は JSON形式の文字列で届き、これを扱い易い Map型に変換している(Janson.decode)。尚、返信には :OK(推論が成功した) と :error(何らかのエラーが発生した)の2種類がある。

架け橋に必要な Elixirコードはたったコレだけ。めちゃ Simple!!!

tfl_interp.ex
defmodule PlugMnist.TflInterp do
  use GenServer

  def start_link(opts \\ []) do
    GenServer.start_link(__MODULE__, opts, name: __MODULE__)
  end

  def predict(img_file) do
    GenServer.call(__MODULE__, {:predict, img_file})
    |> IO.inspect
  end


  def init(opts) do
    executable = Application.app_dir(:plug_mnist, "priv/tfl_interp.exe")
    tfl_model  = Application.app_dir(:plug_mnist, Keyword.get(opts, :model))

    port = Port.open({:spawn_executable, executable}, [
      {:args, [tfl_model]},
      {:packet, 2},
      :binary
    ])

    {:ok, %{port: port}}
  end

  def handle_call({:predict, img_file}, _from, state) do
    Port.command(state.port, "predict #{img_file}")
    response = receive do
      {_, {:data, <<response::binary>>}} ->
        {:ok, ans} = Jason.decode(response)
        if Map.has_key?(ans, "error"), do: {:error, ans["error"]}, else: {:ok, ans}
    after
      5000 -> {:timeout}
    end

    {:reply, response, state}
  end
end

5.TensorFlow liteサイド - tfl_interpインタープリタ

次に見るのは、TensorFlow lite推論器の実装だ。見ての通り C++で記述している。Elixirとのi/f部(tfl_interp.cc)とMNISTモデルによる推論器部(predict.cc)の2つのファイルで構成することにした。仮に、別用途でモデルを取り替えることになったとしても、Elixir i/f部はそのまま流用できるようにしたいからだ。

下のコードが、Elixir i/f部"tfl_interp.cc"の全コードだ。何やらずらずらとコードが並んでいて難しそうに見える[*2]。大丈夫、そのほとんどは stdinからのコマンド文字列の待ち受け&パースと、stdoutへの推論結果の返信を下請けするヘルパー・ルーチンだ。気兼ねなく読み飛ばしOK(^^;)

メイン・ルーチンは interp()だ。少し詳しく見ておこう。まず最初に、引数で受け取ったモデル・ファイル名で tflite::interpreterのインスタンスを生成しセットアップする。この辺りのコードは、TensorFlowのソースコードに同梱の minimal.ccを参考にした。

その後、コマンド待ちループを開始し、Elixirの Portから stdinにコマンド文字列が投げられるのを待つ。コマンド文字列を受け取ったならば、その文字列をパースし "predict"または "info"を実行する。最後に、実行したコマンドの結果(json)を stdoutに出力し、再びコマンド文字列待ちに戻る。

今回の Port実行部は、1 by 1のシングルタスクで十分な仕様であったので、read() APIを用いて処理を中断ブロックすることにした。しかしながら、Nervesのi/oモジュール等に見られるように、マルチタスクが必要な場合では poll() APIの利用や Multi Thread化が必要だろう。

なお設計上のポイントとして、Elixirの Portが何らかの理由で閉じられ stdin/stdoutが閉じられた場合には、コマンド待ちループを脱出して "tfl_interp.exe"をターミネートさせねばならない点に注意しよう。さもなければゾンビが生まれてしまう。

[*2]何も考えずに GNU catのソースを参考に低レベルI/Oでガシガシしたのでコードが長くなってしまった(汗)。

tfl_interp.cc
/***  File Header  ************************************************************/
/**
* tfl_interp.cc
*
* Elixir/Erlang Port ext. of tensor flow lite
* @date create Sat Sep 26 06:26:30 JST 2020
* System       MINGW64/Windows 10<br>
*
**/
/**************************************************************************{{{*/

#include <unistd.h>
#include <iostream>

#include <cerrno>
#ifdef EINTR
#  define IS_EINTR(x) ((x) == EINTR)
#else
#  define IS_EINTR(x) 0
#endif

#include <iterator>
#include <regex>
#include <string>
using namespace std;

#include <nlohmann/json.hpp>
using json = nlohmann::json;

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
using namespace tflite;

struct {
    string mPath;
    string mExe;
    string mTflModel;
} gInfo;

/***  Module Header  ******************************************************}}}*/
/**
* read specific sized byte from stdin
* @par DESCRIPTION
*
* @return count of received byte or error code
**/
/**************************************************************************{{{*/
ssize_t
full_read(int fd, void *buf, size_t count)
{
    char *ptr = reinterpret_cast<char *>(buf);

    size_t total = 0;
    while (count > 0) {
        ssize_t n;
        do {
            n = read(fd, buf, count);
        } while (n < 0 && IS_EINTR(errno));

        if (n < 0) {
            //error occured
            return n;
        }
        else if (n == 0) {
            // termination: total is less than required count
            return total;
        }

        total += n;
        ptr += n;
        count -= n;
    }

    return total;
}

/***  Module Header  ******************************************************}}}*/
/**
* write specific sized byte to stdout
* @par DESCRIPTION
*
* @return count of sent byte or error code
**/
/**************************************************************************{{{*/
size_t
full_write(int fd, const void *buf, size_t count)
{
    size_t total = 0;
    const char *ptr = (const char *) buf;

    while (count > 0) {
        size_t n = write(fd, ptr, count);

        if (n < 0) {
            // error occured
            return n;
        }
        else if (n == 0) {
            // termination: total is less than required count
            return total;
        }

        total += n;
        ptr += n;
        count -= n;
    }

    return total;
}

/***  Module Header  ******************************************************}}}*/
/**
* receive command packet from Elixir/Erlang
* @par DESCRIPTION
*   receive command packet and store it to "buff"
*
* @retval res >  0  success
* @retval res == 0  termination
* @retval res <  0  error
**/
/**************************************************************************{{{*/
ssize_t
rcv_packet(string& cmd_line)
{
    // receive packet size
    char big_endian[2];
    ssize_t n = full_read(STDIN_FILENO, big_endian, sizeof(big_endian));
    if (n <= 0) {
        return n;
    }
    else if (n < 2) {
        errno = ENODATA;
        return ((ssize_t)-1);
    }
    unsigned short len = (unsigned short)(big_endian[0] << 8 | big_endian[1]);

    // receive packet payload
    unique_ptr<char[]> buff(new char[len]);
    n = full_read(STDIN_FILENO, buff.get(), len);
    if (n <= 0) {
        return n;
    }
    else if (n < len) {
        errno = ENODATA;
        return ((ssize_t)-1);
    }

    // return received command line
    cmd_line.assign(buff.get(), n);
    return n;
}

/***  Module Header  ******************************************************}}}*/
/**
* send result packet to Elixir/Erlang
* @par DESCRIPTION
*   construct message packet and send it to stdout
*
* @return count of sent byte or error code
**/
/**************************************************************************{{{*/
ssize_t
snd_packet(string result)
{
    unsigned short len = result.size();

    char big_endian[2];
    big_endian[0] = 0xff & (len >> 8);
    big_endian[1] = 0xff & (len);
    result.insert(0, big_endian, sizeof(big_endian));

    return full_write(STDOUT_FILENO, result.c_str(), len+2);
}

/***  Module Header  ******************************************************}}}*/
/**
* parse command line string
* @par DESCRIPTION
*   extract command & arguments string from string
*
* @retval command string & vector of arguments
**/
/**************************************************************************{{{*/
string
parse_cmd_line(const string& cmd_line, vector<string>& args)
{
    regex reg(R"(\s+)");
    sregex_token_iterator iter(cmd_line.begin(), cmd_line.end(), reg, -1);
    sregex_token_iterator end;

    string command = *iter++;
    args.assign(iter, end);

    return command;
}

/***  Module Header  ******************************************************}}}*/
/**
* tensor flow lite interpreter
* @par DESCRIPTION
*   <<解説記入>>
**/
/**************************************************************************{{{*/
void
interp(const char* tfl_name)
{
    // initialize tensor flow lite
    unique_ptr<tflite::FlatBufferModel> model =
        tflite::FlatBufferModel::BuildFromFile(tfl_name);

    tflite::ops::builtin::BuiltinOpResolver resolver;
    InterpreterBuilder builder(*model, resolver);
    unique_ptr<Interpreter> interpreter;
    builder(&interpreter);

    interpreter->AllocateTensors();

    // REPL
    for (;;) {
        // receive command packet
        string cmd_line;
        ssize_t n = rcv_packet(cmd_line);
        if (n <= 0) {
            break;
        }

        // parse command line
        vector<string> args;
        const string command = parse_cmd_line(cmd_line, args);

        json result;
        result.clear();
        if (command == "predict") {
            extern void predict(unique_ptr<Interpreter>& interpreter, const vector<string>& args, json& result);
            predict(interpreter, args, result);
        }
        else if (command == "info") {
            result["exe"]   = gInfo.mExe;
            result["model"] = gInfo.mTflModel;
        }
        else {
            result["unknown"] = command;
        }

        n = snd_packet(result.dump());
        if (n <= 0) {
            break;
        }
    }
}

/***  Module Header  ******************************************************}}}*/
/**
* tensor flow lite for Elixir/Erlang Port ext.
* @par DESCRIPTION
*   Elixir/Erlang Port extension (experimental)
*
* @return exit status
**/
/**************************************************************************{{{*/
int
main(int argc, char* argv[])
{
    if (argc < 2) {
        // argument error
        cerr << "expect <model file>\n";
        return 1;
    }

    // save exe infomations
    gInfo.mExe.assign(argv[0]);
    gInfo.mTflModel.assign(argv[1]);

    interp(argv[1]);

    return 0;
}

/*** tfl_interp.cc ********************************************************}}}*/

推論器部(predict.cc)は下記の通り。推論器のコードは、interpreterから取り出した inputテンソルに入力画像をセットし、Invoke()で推論を実行し、最後に outputテンソルから推論結果を取り出す3ステップだけだ。その他のごちゃごちゃしたコードは、入力画像を Interpreterに渡せるように加工する画像処理。

巷でイメージ・プロセッシング/マシン・ビジョンと言えば OpenCVであろう。しかしながら、たかが画像の整形ぐらいしか行わないアプリに、重装備な OpenCVを持ち込むのは気が引けてしまうではないか。と、いう訳で lightweightな CImgというライブラリを使ってみることにした。

CImgは C++のテンプレートで実装されたライブラリだ。利用に当たっては、基本的にヘッダー・ファイル"CImg.h"をインクルードするだけでよい[*3]。OpenCVに代表される何でも揃うECモール型のライブラリとは異なり、画像のload/saveと基本的な画像操作ルーチンを備え、画像の内部データ形式を4次元配列に統一しておくので足りないモノがあれば DIYしてね💕というスタンスのライブラリだ。

下のコードでは、RGB画像からGRAY画像への変換getRGBtoGRAY()、輝度反転inverse()を自前で実装している……この程度の機能はライブラリで用意しておいて欲しいところなのだが(- -;)。まぁ、使ってみると意外と小回りが利いて良さそうだ。

[*3]実際には、2~3のライブラリをリンクする必要がある。例えばjpegファイルのロードには libjpeg.aのリンクが必要となる。

tfl_predict.cc
/***  File Header  ************************************************************/
/**
* tfl_predict.cc
*
* tensor flow lite prediction
* @date create Sat Sep 26 06:26:30 JST 2020
* System       MINGW64/Windows 10<br>
*
**/
/**************************************************************************{{{*/
#ifndef cimg_plugin

#define cimg_plugin "tfl_predict.cc"
#define cimg_use_jpeg
#include "CImg.h"
using namespace cimg_library;

#include <string>
using namespace std;

#include <nlohmann/json.hpp>
using json = nlohmann::json;

#include "tensorflow/lite/interpreter.h"
using namespace tflite;

/***  Module Header  ******************************************************}}}*/
/**
* parse command line string
* @par DESCRIPTION
*   extract command & arguments string from string
*
* @retval command string & vector of arguments
**/
/**************************************************************************{{{*/
void
predict(unique_ptr<Interpreter>& interpreter, const vector<string>& args, json& result)
{
/*PRECONDITION*/
    if (args.size() < 1) {
        result["error"] = "not enough argument";
        return;
    }
/**/

    // setup
    float* input = interpreter->typed_input_tensor<float>(0);

    try {
        CImg<unsigned char> image(args[0].c_str());
        auto gray = image.getRGBtoGRAY().resize(28,28).inverse();
        cimg_foroff(gray, i) {
            input[i] = gray[i]/255.;
        }
    }
    catch (...) {
        result["error"] = "fail CImg";
        return;
    }

    // predict
    if (interpreter->Invoke() == kTfLiteOk) {
        // get result
        float* probs = interpreter->typed_output_tensor<float>(0);
        for (int i = 0; i < 10; i++) {
            result[to_string(i)] = probs[i];
        }
    }
    else {
        result["error"] = "fail predict";
    }
}

#else
/**************************************************************************}}}*/
/*** CImg Plugins:                                                          ***/
/**************************************************************************{{{*/
CImg<T> getRGBtoGRAY()
{
    if (_spectrum != 3) {
        throw CImgInstanceException(_cimg_instance
                                    "getRGBtoGRAY(): Instance is not a RGB image.",
                                    cimg_instance);
    }
    CImg<T> res(width(), height(), depth(), 1);
    T *p1 = data(0,0,0,0), *p2 = data(0,0,0,1), *p3 = data(0,0,0,2), *Y = res.data(0,0,0,0);
    const longT whd = (longT)width()*height()*depth();
    cimg_pragma_openmp(parallel for cimg_openmp_if_size(whd,256))
    for (longT i = 0; i < whd; i++) {
        const T
          R = p1[i],
          G = p2[i],
          B = p3[i];
        Y[i] = (T)(0.299f*R + 0.587f*G + 0.114f*B);
    }
    return res;
}

/******************************************************************************/
CImg<T> inverse()
{
    if (_spectrum != 1) {
        throw CImgInstanceException(_cimg_instance
                                    "inverse(): Instance is not a RGB image.",
                                    cimg_instance);
    }
    T *p1 = data(0,0,0,0);
    const longT whd = (longT)width()*height()*depth();
    cimg_pragma_openmp(parallel for cimg_openmp_if_size(whd,256))
    for (longT i = 0; i < whd; i++) {
        p1[i] = cimg::type<T>::max() - p1[i];
    }
    return *this;
}
#endif
/*** tfl_predict.cc *******************************************************}}}*/

6.動作テスト

それでは、動作テストをしてみよう。
おおおっ、手書きで書いた"3"を正しく認識した。左のウインドには "0"~"9"のそれぞれの確からしさを表示している。めでたし、めでたし。

※ 実際に PlugMnistアプリを実行するには、makeしたり mixしたりと幾つかの手順を経る必要があるのだが、誠に勝手ながら割愛させて頂く。御免m(_ _)m
plug_mnist_test.jpg

7.展望

よしっ、最初のハードルは越えた。これからが本番だ。元々の目標は TensorFlow liteを Nervesで利用することであった。ぼちぼちと進めていこう。

1.今回作成した PlugMnistをそのまま Nerves RasPiにポーティングする
2.Nerves RasPiのカメラ・モジュールで撮影した画像で推論ができるように改造する
3.物体検出や面白そうなモデルをインストールして遊んでみる
4.liteではなく本家のTensorFlowを利用できるようにしてみる

などなど(^_-)…… AI-Car構想はどうなった(汗)

参考文献

[1] elixir documentation: Port
[2] ERLANG: Interoperability Tutorial User's Guide - Ports
[3] Erlang Run-Time System Application (ERTS) Reference Manual - open_port
[4] 実況: mutableなストレージのNIFsを実装してみる
[5] TensorFlow lite: マイクロコントローラを使ってみる
[6] The CImg Library - C++ template image processing toolkit

6
1
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
1