以前、『TensorFlowの学習済みモデルを拾ってきてiOSで利用する』という記事を書いたのですが、そのとき用いたモデルデータはいずれも Protocol Buffers 形式でした。
(YOLOモデルでリアルタイム一般物体認識)
Protocol Buffersはプラットフォーム等を問わない汎用的なフォーマットですが、上の記事を書いたときにわからなかったのが、.pb
でエクスポートされたTensorFlow用のモデルでも、iOSで使えるものと使えないものの違いはあるのかないのか、あるとしたら何なのか(単にサイズや使用メモリ量とかが制限になってくるのか)、という点です。
理解の手がかりになるかわかりませんが、TensorFlowのiOSサンプルがどうやってモデルデータを読み込んでいるか、コードをちょっとだけ追ってみます。
なお、今回利用するサンプルはsimpleです。
##.pbファイル読み込み〜TensorFlowグラフ生成
###1. モデルのファイルパスを取得
まず、.pb
ファイルのファイル名を渡しているところはここ。
NSString* network_path = FilePathForResourceName(@"tensorflow_inception_graph", @"pb");
このFilePathForResourceName()
は次のように実装されています。
NSString* FilePathForResourceName(NSString* name, NSString* extension) {
NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension];
if (file_path == NULL) {
LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "."
<< [extension UTF8String] << "' in bundle.";
}
return file_path;
}
アプリバンドル内にあるファイルのパスを取得しているだけ、ということがわかります。
###2. Protocol Buffersをデコード
そして次に、そのファイルパスと、tensorflow::GraphDef
へのポインタをPortableReadFileToProto()
という関数の引数に渡しています。
PortableReadFileToProto([network_path UTF8String], &tensorflow_graph);
この関数の中身を見てみると、こんな感じです。
bool PortableReadFileToProto(const std::string& file_name,
::google::protobuf::MessageLite* proto) {
::google::protobuf::io::CopyingInputStreamAdaptor stream(
new IfstreamInputStream(file_name));
stream.SetOwnsCopyingStream(true);
// TODO(jiayq): the following coded stream is for debugging purposes to allow
// one to parse arbitrarily large messages for MessageLite. One most likely
// doesn't want to put protobufs larger than 64MB on Android, so we should
// eventually remove this and quit loud when a large protobuf is passed in.
::google::protobuf::io::CodedInputStream coded_stream(&stream);
// Total bytes hard limit / warning limit are set to 1GB and 512MB
// respectively.
coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
return proto->ParseFromCodedStream(&coded_stream);
}
コメントにも書いてありますが、SetTotalBytesLimit()
という関数を用いて、読み込むモデルファイルのサイズを1024MB(1GB)に制限(512MB以上でwarning)しています。**cameraサンプルでも同様の実装**になっていました。
ここではファイルを読み込み、ParseFromCodedStream()
でProtocol Buffersをデコードしています。
ちなみに、第2引数が::google::protobuf::MessageLite
となっていますが、tensorflow::GraphDef
の定義を見ると、次のようになっています。
class GraphDef : public ::google::protobuf::Message
GraphDef
はprotobuf::Message
を継承しているクラスであることがわかります。
###3. グラフを作成
最後に、Session::Create()
で、セッションで使用するグラフを作成します。
tensorflow::Status s = session->Create(tensorflow_graph);
##グラフの実行
ここまで、モデルに依存する部分はファイル名だけだったわけですが、グラフをRunするところでは、次のように入力・出力ノード名をそれぞれ文字列で渡しています。
std::string input_layer = "input";
std::string output_layer = "output";
std::vector<tensorflow::Tensor> outputs;
tensorflow::Status run_status = session->Run({{input_layer, image_tensor}},
{output_layer}, {}, &outputs);
このへんはモデル(グラフ)に依存してるところで、先日の公開されているモデルを利用する記事でも、モデルに応じて変えていました。
##グラフの中身を出力してみる
先程、tensorflow::GraphDef
はprotobuf::Message
のサブクラスである、と書きました。
Message
には、DebugString()
という、message(ここではProtocol Buffersをデコードして得られたデータ)を人間に読める形で出力してくれるメソッドが用意されています。
// Generates a human readable form of this message, useful for debugging
// and other purposes.
string DebugString() const;
で、これをコンソールに出力してくれるメソッドもあります。
void PrintDebugString() const;
これを使って出力してみます。
tensorflow_graph.PrintDebugString();
node {
name: "input"
op: "Placeholder"
device: "/cpu:0"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
}
}
}
}
node {
name: "conv2d0_w"
op: "Const"
device: "/cpu:0"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 7
}
dim {
size: 7
}
dim {
size: 3
}
dim {
size: 64
}
}
tensor_content: "{超長いバイナリ(?)データ}"
}
}
}
}
node {
name: "conv2d0_b"
op: "Const"
device: "/cpu:0"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 64
}
}
tensor_content: "{短めのバイナリ(?)データ}"
}
}
}
}
node {
name: "conv2d1_w"
op: "Const"
device: "/cpu:0"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 1
}
dim {
size: 1
}
dim {
size: 64
}
dim {
size: 64
}
}
tensor_content: "{かなり長いバイナリ(?)データ}"
}
}
}
}
(以下略)
ということで、グラフの中身が見れて、"input"というノード名も見つけられました。.pb
形式のTensorFlow用学習済みモデルが公開されているものの、Run()
の引数に渡すノード名がわからないということもあったので、ここらへんを手がかりにできそうです。
##まとめ
TensorFlowのiOSサンプルのモデルまわりのコードを追ってみました。こんなことをするよりもそもそもちゃんとTensorFlowや機械学習について理解しなさいという話はありつつも、前回時点では(自分にとって)完全なブラックボックスだったモデルというものが、
- Protocol Buffersというフォーマット自体はプラットフォームを問わない汎用的なものである
- 1024MBのモデルのサイズ制限は、サンプルアプリ側で行っている(TensorFlow for iOSの制約ではない)
- モデルを読み込みグラフを作成するところまでは、どのモデルも同様の手順である
- Runするところでモデル(グラフ)に依存した手続きが必要
-
protobuf::Message::PrintDebugString()
でProtocol Buffersからデコードしてきたデータの中身を見れるので、他人が作ったモデル(グラフ)を使うときにここをヒントにできそう
-
と色々と手がかりが得られました。