LoginSignup
35
27

More than 5 years have passed since last update.

以前、『TensorFlowの学習済みモデルを拾ってきてiOSで利用する』という記事を書いたのですが、そのとき用いたモデルデータはいずれも Protocol Buffers 形式でした。

(YOLOモデルでリアルタイム一般物体認識)

Protocol Buffersはプラットフォーム等を問わない汎用的なフォーマットですが、上の記事を書いたときにわからなかったのが、.pbでエクスポートされたTensorFlow用のモデルでも、iOSで使えるものと使えないものの違いはあるのかないのか、あるとしたら何なのか(単にサイズや使用メモリ量とかが制限になってくるのか)、という点です。

理解の手がかりになるかわかりませんが、TensorFlowのiOSサンプルがどうやってモデルデータを読み込んでいるか、コードをちょっとだけ追ってみます。

なお、今回利用するサンプルはsimpleです。

.pbファイル読み込み〜TensorFlowグラフ生成

1. モデルのファイルパスを取得

まず、.pbファイルのファイル名を渡しているところはここ。

RunModelViewController.mm
NSString* network_path = FilePathForResourceName(@"tensorflow_inception_graph", @"pb");

このFilePathForResourceName()は次のように実装されています。

RunModelViewController.mm
NSString* FilePathForResourceName(NSString* name, NSString* extension) {
  NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension];
  if (file_path == NULL) {
    LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "."
           << [extension UTF8String] << "' in bundle.";
  }
  return file_path;
}

アプリバンドル内にあるファイルのパスを取得しているだけ、ということがわかります。

2. Protocol Buffersをデコード

そして次に、そのファイルパスと、tensorflow::GraphDefへのポインタをPortableReadFileToProto()という関数の引数に渡しています。

RunModelViewController.mm
PortableReadFileToProto([network_path UTF8String], &tensorflow_graph);

この関数の中身を見てみると、こんな感じです。

RunModelViewController.mm
bool PortableReadFileToProto(const std::string& file_name,
                             ::google::protobuf::MessageLite* proto) {
  ::google::protobuf::io::CopyingInputStreamAdaptor stream(
      new IfstreamInputStream(file_name));
  stream.SetOwnsCopyingStream(true);
  // TODO(jiayq): the following coded stream is for debugging purposes to allow
  // one to parse arbitrarily large messages for MessageLite. One most likely
  // doesn't want to put protobufs larger than 64MB on Android, so we should
  // eventually remove this and quit loud when a large protobuf is passed in.
  ::google::protobuf::io::CodedInputStream coded_stream(&stream);
  // Total bytes hard limit / warning limit are set to 1GB and 512MB
  // respectively. 
  coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
  return proto->ParseFromCodedStream(&coded_stream);
}

コメントにも書いてありますが、SetTotalBytesLimit()という関数を用いて、読み込むモデルファイルのサイズを1024MB(1GB)に制限(512MB以上でwarning)しています。cameraサンプルでも同様の実装になっていました。

ここではファイルを読み込み、ParseFromCodedStream()でProtocol Buffersをデコードしています。

ちなみに、第2引数が::google::protobuf::MessageLiteとなっていますが、tensorflow::GraphDefの定義を見ると、次のようになっています。

graph.pb.h
class GraphDef : public ::google::protobuf::Message

GraphDefprotobuf::Messageを継承しているクラスであることがわかります。

3. グラフを作成

最後に、Session::Create()で、セッションで使用するグラフを作成します。

RunModelViewController.mm
tensorflow::Status s = session->Create(tensorflow_graph);

グラフの実行

ここまで、モデルに依存する部分はファイル名だけだったわけですが、グラフをRunするところでは、次のように入力・出力ノード名をそれぞれ文字列で渡しています。

RunModelViewController.mm
std::string input_layer = "input";
std::string output_layer = "output";
std::vector<tensorflow::Tensor> outputs;
tensorflow::Status run_status = session->Run({{input_layer, image_tensor}},
                           {output_layer}, {}, &outputs);

このへんはモデル(グラフ)に依存してるところで、先日の公開されているモデルを利用する記事でも、モデルに応じて変えていました。

グラフの中身を出力してみる

先程、tensorflow::GraphDefprotobuf::Messageのサブクラスである、と書きました。

Messageには、DebugString()という、message(ここではProtocol Buffersをデコードして得られたデータ)を人間に読める形で出力してくれるメソッドが用意されています。

// Generates a human readable form of this message, useful for debugging
// and other purposes.
string DebugString() const;

で、これをコンソールに出力してくれるメソッドもあります。

void PrintDebugString() const;

これを使って出力してみます。

RunModelViewController.mm
tensorflow_graph.PrintDebugString();
node {
  name: "input"
  op: "Placeholder"
  device: "/cpu:0"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
      }
    }
  }
}
node {
  name: "conv2d0_w"
  op: "Const"
  device: "/cpu:0"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
          dim {
            size: 7
          }
          dim {
            size: 7
          }
          dim {
            size: 3
          }
          dim {
            size: 64
          }
        }
        tensor_content: "{超長いバイナリ(?)データ}"
      }
    }
  }
}
node {
  name: "conv2d0_b"
  op: "Const"
  device: "/cpu:0"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
          dim {
            size: 64
          }
        }
        tensor_content: "{短めのバイナリ(?)データ}"
      }
    }
  }
}
node {
  name: "conv2d1_w"
  op: "Const"
  device: "/cpu:0"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
          dim {
            size: 1
          }
          dim {
            size: 1
          }
          dim {
            size: 64
          }
          dim {
            size: 64
          }
        }
        tensor_content: "{かなり長いバイナリ(?)データ}"
      }
    }
  }
}
(以下略)

ということで、グラフの中身が見れて、"input"というノード名も見つけられました。.pb形式のTensorFlow用学習済みモデルが公開されているものの、Run()の引数に渡すノード名がわからないということもあったので、ここらへんを手がかりにできそうです。

まとめ

TensorFlowのiOSサンプルのモデルまわりのコードを追ってみました。こんなことをするよりもそもそもちゃんとTensorFlowや機械学習について理解しなさいという話はありつつも、前回時点では(自分にとって)完全なブラックボックスだったモデルというものが、

  • Protocol Buffersというフォーマット自体はプラットフォームを問わない汎用的なものである
  • 1024MBのモデルのサイズ制限は、サンプルアプリ側で行っている(TensorFlow for iOSの制約ではない
  • モデルを読み込みグラフを作成するところまでは、どのモデルも同様の手順である
  • Runするところでモデル(グラフ)に依存した手続きが必要
    • protobuf::Message::PrintDebugString()でProtocol Buffersからデコードしてきたデータの中身を見れるので、他人が作ったモデル(グラフ)を使うときにここをヒントにできそう

と色々と手がかりが得られました。

35
27
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
35
27