Help us understand the problem. What is going on with this article?

TensorFlow for iOS のモデルについて

More than 3 years have passed since last update.

以前、『TensorFlowの学習済みモデルを拾ってきてiOSで利用する』という記事を書いたのですが、そのとき用いたモデルデータはいずれも Protocol Buffers 形式でした。

(YOLOモデルでリアルタイム一般物体認識)

Protocol Buffersはプラットフォーム等を問わない汎用的なフォーマットですが、上の記事を書いたときにわからなかったのが、.pbでエクスポートされたTensorFlow用のモデルでも、iOSで使えるものと使えないものの違いはあるのかないのか、あるとしたら何なのか(単にサイズや使用メモリ量とかが制限になってくるのか)、という点です。

理解の手がかりになるかわかりませんが、TensorFlowのiOSサンプルがどうやってモデルデータを読み込んでいるか、コードをちょっとだけ追ってみます。

なお、今回利用するサンプルはsimpleです。

.pbファイル読み込み〜TensorFlowグラフ生成

1. モデルのファイルパスを取得

まず、.pbファイルのファイル名を渡しているところはここ。

RunModelViewController.mm
NSString* network_path = FilePathForResourceName(@"tensorflow_inception_graph", @"pb");

このFilePathForResourceName()は次のように実装されています。

RunModelViewController.mm
NSString* FilePathForResourceName(NSString* name, NSString* extension) {
  NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension];
  if (file_path == NULL) {
    LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "."
           << [extension UTF8String] << "' in bundle.";
  }
  return file_path;
}

アプリバンドル内にあるファイルのパスを取得しているだけ、ということがわかります。

2. Protocol Buffersをデコード

そして次に、そのファイルパスと、tensorflow::GraphDefへのポインタをPortableReadFileToProto()という関数の引数に渡しています。

RunModelViewController.mm
PortableReadFileToProto([network_path UTF8String], &tensorflow_graph);

この関数の中身を見てみると、こんな感じです。

RunModelViewController.mm
bool PortableReadFileToProto(const std::string& file_name,
                             ::google::protobuf::MessageLite* proto) {
  ::google::protobuf::io::CopyingInputStreamAdaptor stream(
      new IfstreamInputStream(file_name));
  stream.SetOwnsCopyingStream(true);
  // TODO(jiayq): the following coded stream is for debugging purposes to allow
  // one to parse arbitrarily large messages for MessageLite. One most likely
  // doesn't want to put protobufs larger than 64MB on Android, so we should
  // eventually remove this and quit loud when a large protobuf is passed in.
  ::google::protobuf::io::CodedInputStream coded_stream(&stream);
  // Total bytes hard limit / warning limit are set to 1GB and 512MB
  // respectively. 
  coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
  return proto->ParseFromCodedStream(&coded_stream);
}

コメントにも書いてありますが、SetTotalBytesLimit()という関数を用いて、読み込むモデルファイルのサイズを1024MB(1GB)に制限(512MB以上でwarning)しています。cameraサンプルでも同様の実装になっていました。

ここではファイルを読み込み、ParseFromCodedStream()でProtocol Buffersをデコードしています。

ちなみに、第2引数が::google::protobuf::MessageLiteとなっていますが、tensorflow::GraphDefの定義を見ると、次のようになっています。

graph.pb.h
class GraphDef : public ::google::protobuf::Message

GraphDefprotobuf::Messageを継承しているクラスであることがわかります。

3. グラフを作成

最後に、Session::Create()で、セッションで使用するグラフを作成します。

RunModelViewController.mm
tensorflow::Status s = session->Create(tensorflow_graph);

グラフの実行

ここまで、モデルに依存する部分はファイル名だけだったわけですが、グラフをRunするところでは、次のように入力・出力ノード名をそれぞれ文字列で渡しています。

RunModelViewController.mm
std::string input_layer = "input";
std::string output_layer = "output";
std::vector<tensorflow::Tensor> outputs;
tensorflow::Status run_status = session->Run({{input_layer, image_tensor}},
                           {output_layer}, {}, &outputs);

このへんはモデル(グラフ)に依存してるところで、先日の公開されているモデルを利用する記事でも、モデルに応じて変えていました。

グラフの中身を出力してみる

先程、tensorflow::GraphDefprotobuf::Messageのサブクラスである、と書きました。

Messageには、DebugString()という、message(ここではProtocol Buffersをデコードして得られたデータ)を人間に読める形で出力してくれるメソッドが用意されています。

// Generates a human readable form of this message, useful for debugging
// and other purposes.
string DebugString() const;

で、これをコンソールに出力してくれるメソッドもあります。

void PrintDebugString() const;

これを使って出力してみます。

RunModelViewController.mm
tensorflow_graph.PrintDebugString();
node {
  name: "input"
  op: "Placeholder"
  device: "/cpu:0"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
      }
    }
  }
}
node {
  name: "conv2d0_w"
  op: "Const"
  device: "/cpu:0"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
          dim {
            size: 7
          }
          dim {
            size: 7
          }
          dim {
            size: 3
          }
          dim {
            size: 64
          }
        }
        tensor_content: "{超長いバイナリ(?)データ}"
      }
    }
  }
}
node {
  name: "conv2d0_b"
  op: "Const"
  device: "/cpu:0"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
          dim {
            size: 64
          }
        }
        tensor_content: "{短めのバイナリ(?)データ}"
      }
    }
  }
}
node {
  name: "conv2d1_w"
  op: "Const"
  device: "/cpu:0"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
          dim {
            size: 1
          }
          dim {
            size: 1
          }
          dim {
            size: 64
          }
          dim {
            size: 64
          }
        }
        tensor_content: "{かなり長いバイナリ(?)データ}"
      }
    }
  }
}
(以下略)

ということで、グラフの中身が見れて、"input"というノード名も見つけられました。.pb形式のTensorFlow用学習済みモデルが公開されているものの、Run()の引数に渡すノード名がわからないということもあったので、ここらへんを手がかりにできそうです。

まとめ

TensorFlowのiOSサンプルのモデルまわりのコードを追ってみました。こんなことをするよりもそもそもちゃんとTensorFlowや機械学習について理解しなさいという話はありつつも、前回時点では(自分にとって)完全なブラックボックスだったモデルというものが、

  • Protocol Buffersというフォーマット自体はプラットフォームを問わない汎用的なものである
  • 1024MBのモデルのサイズ制限は、サンプルアプリ側で行っている(TensorFlow for iOSの制約ではない
  • モデルを読み込みグラフを作成するところまでは、どのモデルも同様の手順である
  • Runするところでモデル(グラフ)に依存した手続きが必要
    • protobuf::Message::PrintDebugString()でProtocol Buffersからデコードしてきたデータの中身を見れるので、他人が作ったモデル(グラフ)を使うときにここをヒントにできそう

と色々と手がかりが得られました。

shu223
フリーランスiOSエンジニア 著書:『iOS×BLE Core Bluetooth プログラミング』『Metal入門』『実践ARKit』『Depth in Depth』『iOSアプリ開発 達人のレシピ100』他 GitHubの累計スター数23,000超
http://shu223.hatenablog.com/
engineerlife
技術力をベースに人生を謳歌する人たちのコミュニティです。
https://community.camp-fire.jp/projects/view/280040
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした