この記事は(書くと言ってすっかり忘れていた)TensorFlow Lite(以下TFLite)のC++ APIの解説をまとめたものです。Java APIは別の記事にします。あるいはJava APIの基本的な使い方は@tchkwkzkさんが記事にまとめられているのでそちらを参考にしてください。
APIは今後変更される可能性があることに注意してください。
記事投稿現在、TFLiteはドキュメントが十分に整備されていません。最低限のAPI解説(en)はここにあります。ここに記載された内容のみでも最低限のものは書けますが、色々足りてません。
- エラーハンドリング
- モデルデータの作成と読み込み
- インタプリタの作成とグラフの実行
- カスタムオペレーションの実装と登録/カスタムリゾルバの作成
前提条件
TFLiteはGNU C++11を前提に記述されています。またTFLiteは32bitアライメントを想定しているので注意してください。サポートされているプラットフォームはAndroid, iOS, Raspberry Piです。
エラーハンドリング
ErrorReporter
と、ステータスのチェックを使います。
カスタムのエラーレポータは、tflite::ErrorReporter
を継承してReport
をオーバライドします。カスタムエラーレポータを各APIにわたすと、エラー発生時に呼び出されます。デフォルトでエラーは標準エラーに出力するエラーレポータが使用されます。
struct CustomReporter : public tflite::ErrorReporter {
int Report(const char* format, va_list args) override {/*...*/}
};
一部のAPIはステータスコードが返ります。
TfLiteStatus status = //
switch(status) {
case kTfLiteOk: //
case kTfLiteError: //
}
モデルデータの作成と読み込み
FlatBufferModel
クラス
TFLiteで実行できるモデルは、FlatBuffersというデータ形式を用います。Googleの開発しているフットプリントの小さいデータ形式です。モデルデータはtflite::FlatBufferModel
クラスの静的メソッドで、ファイルまたはバイト列から読み込めます。コンストラクタは使わず、スマートポインタで取得する方が望ましいでしょう。
class FlatBufferModel {
static std::unique_ptr<FlatBufferModel> BuildFromFile(
const char* filename,
ErrorReporter* error_reporter);
static std::unique_ptr<FlatBufferModel> BuildFromBuffer(
const char* buffer,
size_t buffer_size,
ErrorReporter* error_reporter);
};
読み込み例:
auto model = tflite::FlatBufferModel::BuildFromFile("path/to/model");
if(!model) {
// 読み込み失敗
}
インタプリタの作成とグラフの実行
tflite::Interpreter
の作成
FlatBuffers
形式のTFLiteモデルを、実際に実行するインタプリタエンジンですtflite::InterpreterBuilder
で作成します。
作成時には、先程のFlatBufferModel
と、オペレーションリゾルバが必要になります。基本的に、カスタムオペレーションを用いない場合は、以下のようにビルトインのオペレーションリゾルバを何もせずそのまま渡せば問題ありません。
std::unique_ptr<tflite::Interpreter> interpreter;
tflite::ops::builtin::BuiltinOpResolver resolver;
tflite::InterpreterBuilder(*model.get(), resolver)(&interpreter);
ビルトインでサポートされているオペレーションについては、ここを参照してください。サポートされていないオペレーションを実行したい場合は、カスタムオペレーションを作成する必要があります(後述)。
グラフの実行
// Tensorをメモリ上に確保します
interpreter->AllocateTensors();
// 0番目の入力を取得します
// 入力が複数ある場合は同様にインデクスで取得できます
float* input = interpreter->typed_input_tensor<float>(0);
// 取得した入力となるTensorのデータへのポインタを介して、入力を代入します
// グラフを実行します
interpreter->Invoke();
// 出力の値も、入力と同じように取得できます
float* output = interpreter->typed_output_tensor<float>(0);
tflite::Interpreter
のメンバ関数
AllocateTensors
TfLiteStatus AllocateTensors();
テンソルのメモリを確保します。
Invoke
TfLiteStatus Invoke();
モデルの計算を実行します。これを呼び出す前に、入力データを代入しておく必要があります。
inputs
/outputs
const std::vector<int>& inputs() const;
const std::vector<int>& outputs() const;
入出力テンソルのインデックスを取得します。テンソルの値ではありません。入出力は、それぞれ1つとは限らないので、ベクトルで返ります。
TFLiteインタプリタ内部に確保されたテンソルは、すべてインデックスが割り当てられています。インデックスを通して、テンソルの情報やデータを操作します。
tensor
/typed_tensor
TfLiteTensor* tensor(int tensor_index);
const TfLiteTensor* tensor(int tensor_index) const;
template <class T>
T* typed_tensor(int tensor_index);
template <class T>
const T* typed_tensor(int tensor_index) const;
テンソルのインデックスから、テンソルの情報を保持しているTfLiteTensor
へのポインタを取得します。TfLiteTensor
から型情報や形状情報、生のデータへのポインタを取得できます。実際に生データを操作する場合は、生ポインタを型情報に基づいてキャストする必要がありますが、その場合はtyped_tensor
を使って安全にキャストされたポインタを取得しましょう。
typed_input_tensor
/typed_output_tensor
template <class T>
T* typed_input_tensor(int index);
template <class T>
T* typed_output_tensor(int index);
モデルの入出力テンソルを取得するユーティリティです。引数のインデックスは、内部のテンソルインデックスではありません。モデルの入出力テンソルの順番を指定します。例えば、float型のモデルの1番目の入力のテンソルを取得するときは、typed_input_tensor<float>(0)
のようにします。なお、データ型が未知の場合は、型情報を使って条件分岐しましょう。
tensors_size
/nodes_size
size_t tensors_size() const;
size_t nodes_size() const;
インタプリタに読み込まれているモデルのテンソルの数とノードの数を取得します。
ResizeInputTensor
TfLiteStatus ResizeInputTensor(int tensor_index, const std::vector<int>& dims);
インデクスで指定したテンソルの入力の形状を変更します。
UseNNAPI
void UseNNAPI(bool enable);
Android OS環境下において、NNAPIによるハードウェアクセラレーションを有効にする場合、true
を渡します。
SetNumThreads
void SetNumThreads(int num_threads);
マルチスレッドで実行したい場合、スレッド数を指定します。
(補足)TfLiteTensor
について
typedef struct {
TfLiteType type;
TfLitePtrUnion data;
TfLiteIntArray* dims;
TfLiteQuantizationParams params;
TfLiteAllocationType allocation_type;
size_t bytes;
const void* allocation;
const char* name;
TfLiteDelegate* delegate;
TfLiteBufferHandle buffer_handle;
bool data_is_stale;
} TfLiteTensor;
TFLiteが保持するテンソルの情報を格納している、Cメモリレイアウトの構造体です。基本的に、値の取得はインタプリタのtyped_tensor
を用いましょう。入力データのバリデーションチェックを行う場合などは、型の情報(type
)や、形状(dims
)を参照しましょう。
TfLiteIntArray
は、以下のような単純なC構造体です。範囲チェックに気をつけましょう。
typedef struct {
int size;
int data[];
} TfLiteIntArray;
カスタムオペレーションの実装と登録/カスタムリゾルバの作成
オペレーションは、以下の構造体で表されます。
typedef struct {
void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
void (*free)(TfLiteContext* context, void* buffer);
TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
} TfLiteRegistration;
見ての通り、4つの関数ポインタを保持します。モデルがインタプリタに読み込まれた時、各ノードごとに対応するオペレーションのinit
が実行されます。ここで必要なメモリを確保します。各init
はそれぞれ対応するfree
が一度だけ呼び出されます。インタプリタの構築後、入力テンソルの形状が変化した場合などに、prepare
が呼び出されます。ここで適切にメモリを開放/再確保する必要があります。そして、実際にインタプリタのInvoke
によってグラフの各ノードが実行される時、invoke
が実行されます。使用しない関数はnullptr
を指定します。
この記事書いてるときにカスタムオペレーションの作成におけるベストプラクティスや追加の属性の付け方とかが追記されたのでその辺りも参考にしてください。
作成したカスタムオペレーションは、リゾルバに登録することで間接的にインタプリタで実行可能になります。
TfLiteRegistration* Register_MY_CUSTOM_OP() {
static TfLiteRegistration r = {/*...*/};
return &r;
}
resolver.AddOp("MY_CUSTOM_OP", Register_MY_CUSTOM_OP());
このリゾルバは、ビルトインオペレーションが全て登録されたtflite::ops::builtin::BuiltinOpResolver
を用いる他に、空のリゾルバから必要なオペレーションのみを追加してカスタムリゾルバを作成することもできます。これは、リソース制約のある環境下において、使用される予定のないオペレーションの実装をバイナリから省きたい場合などに有効な選択肢です。
(プロファイリング)
最近のバージョンでうまく動かせなくなったので動かせたら書きます…
サンプル
GitHubに、落書きを認識するモデルをラズパイで実行するサンプルを置いてあります。
主にTFLiteを使っているのはmain.cppです。
以上