LoginSignup
25
16

More than 5 years have passed since last update.

TensorFlow内部構造解析 (1.1) Protocol Buffers形式のデータ構造

Last updated at Posted at 2018-07-25

TensorFlowには様々なデータ構造が存在します。その中でも計算グラフの定義やチェックポイントの保存に使用されるProtocol Buffers形式で定義されたデータ構造は、TensorFlowの計算グラフの処理で使用されるものであり、最も重要なデータ構造です。
TensorFlowの内部構造解析に入る前に、Protocol Buffers形式で定義された特に重要なデータ構造を紹介し、TensorFlowで扱っているデータ構造をまとめます。

  • TensorFlow v1.9.0
    • コミットID: a44996a84b24c43cca40c685a009fd59275755ab

前知識

Protocol Buffers

Protocol Buffers1は、Googleにより開発されたシリアライズフォーマットです。C++やPython、Go、Javascriptなど、様々な言語から扱うことができます。TensorFlowはProtocol Buffersフォーマットを、チェックポイントファイルの保存や言語間でのデータ転送、計算グラフ構築処理などの様々な場所で利用します。

Protocol Buffersで定義されたファイルは、protocと呼ばれるProtocol Buffersのコンパイラを使って、各言語のソースコードを出力します。例えば、ファイル名が attr_value.proto であるProtocol Buffers形式の定義ファイルを、protocを使ってPython向けにコードを生成した時、Pythonのソースコード attr_value_pb2.py が生成されます。生成された attr_value_pb2 モジュールを使って、Protocol Buffersで定義されたデータ構造にアクセスすることができます。

グラフを表現するデータ構造

TensorFlow内でProtocol Buffersで定義されるグラフを表現するデータ構造は、GraphDef です。GraphDef は、グラフのノードを表現するデータ構造 NodeDef から構成されます。

GraphDef.png

GraphDef

GraphDef は、TensorFlow内の計算グラフを表現するデータ構造で、Protocol Buffersで定義されます。

複数のノードから構成される計算グラフは、計算グラフに含まれるノードを表現するデータ構造 NodeDef のリストを保持しています。また、GraphDef は、TensorFlowのSavedModel 2 のグラフ情報を保存する時に使用する情報であるため、互換性の確認に必要なバージョン情報が含まれています。

tensorflow/core/framework/graph.proto
message GraphDef {
  repeated NodeDef node = 1;
  VersionDef versions = 4;
  int32 version = 3 [deprecated = true];
  FunctionDefLibrary library = 2;
};
フィールド 意味
node グラフのノードを表現する NodeDef のリスト
versions バージョン
version バージョン(Deprecated)
library ユーザ定義の関数

NodeDef

NodeDef は、グラフのノードを表現するデータ構造で、Protocol Buffersで定義されます。

ノード名やノードに対応する演算(Operation)、入力ノードなど、TensorFlowの計算グラフを構築するために必要な情報が含まれています。

tensorflow/core/framework/node_def.proto
message NodeDef {
  string name = 1;
  string op = 2;
  repeated string input = 3;
  string device = 4;
  map<string, AttrValue> attr = 5;
};
フィールド 意味
name ノードを一意に特定するための名前
op ノードに紐づいたOperation名
input 入力ノードのリスト
device ユーザからリクエストされた、演算を実行するデバイス
attr 属性情報

AttrValue

AttrValue は、値を保持するデータ構造で、Protocol Buffersで定義されます。

oneof 修飾子により、ある1つのデータ型のデータを保持することができます。

tensorflow/core/framework/attr_value.proto
message AttrValue {
  message ListValue {
    repeated bytes s = 2;
    repeated int64 i = 3 [packed = true];
    repeated float f = 4 [packed = true];
    repeated bool b = 5 [packed = true];
    repeated DataType type = 6 [packed = true];
    repeated TensorShapeProto shape = 7;
    repeated TensorProto tensor = 8;
    repeated NameAttrList func = 9;
  }

  oneof value {
    bytes s = 2;
    int64 i = 3;
    float f = 4;
    bool b = 5;
    DataType type = 6;
    TensorShapeProto shape = 7;
    TensorProto tensor = 8;
    ListValue list = 1;
    NameAttrList func = 10;
    string placeholder = 9;
  }
}
フィールド
s 文字列
i 整数値
f 浮動小数点数
b ブール値
type DataType(TensorFlow内部で使用されるデータ型)
shape TensorShapeProto(テンソルの形)
tensor TensorProto(テンソルの実体)
list リスト(ListValue
func ユーザ定義関数
placeholder ユーザ定義関数内で利用されるPlaceholder

テンソルを表現するデータ構造

TensorFlow内でProtocol Buffersで定義されるテンソルを表現するデータ構造は、TensorProto です。TensorProto は、テンソルの形を表現するデータ構造 TensorShapeProto とテンソルの各要素のデータから構成されます。

TensorProto.png

TensorProto

TensorProto は、テンソルを表現するためのデータ構造で、Protocol Buffersで定義されます。

フィールド dtype によって TensorProto に保存されるデータ型が決まり、dtype に対応する <type>_val に、テンソルの各要素の値が保持されます。(<type> は各データ型)また、<type>_val をシリアライズしたデータが tensor_content に保持されています。

tensorflow/core/framework/tensor.proto
message TensorProto {
  DataType dtype = 1;
  TensorShapeProto tensor_shape = 2;
  int32 version_number = 3;
  bytes tensor_content = 4;
  repeated int32 half_val = 13 [packed = true];
  repeated float float_val = 5 [packed = true];
  repeated double double_val = 6 [packed = true];
  repeated int32 int_val = 7 [packed = true];
  repeated bytes string_val = 8;
  repeated float scomplex_val = 9 [packed = true];
  repeated int64 int64_val = 10 [packed = true];
  repeated bool bool_val = 11 [packed = true];
  repeated double dcomplex_val = 12 [packed = true];
  repeated ResourceHandleProto resource_handle_val = 14;
  repeated VariantTensorDataProto variant_val = 15;
  repeated uint32 uint32_val = 16 [packed = true];
  repeated uint64 uint64_val = 17 [packed = true];
};
フィールド 意味
dtype データ型(DataType
tensor_shape テンソルの形(TensorShapeProto
version_number バージョン
tensor_content シリアライズ形式のデータ

TensorShapeProto

TensorShapeProto は、TensorProto で表現されるテンソルの形を表現する、Protocol Buffersで定義されたデータ構造です。

フィールド dim がテンソルの形を表し、テンソルの各軸の要素数が保存されています。

tensorflow/core/framework/tensor_shape.proto
message TensorShapeProto {
  message Dim {
    int64 size = 1;
    string name = 2;
  };
  repeated Dim dim = 2;
  bool unknown_rank;
};

ファンクションを表現するデータ構造

TensorFlowには、ファンクションと呼ばれる機能があります。ファンクションは、複数の演算を1つのノードとして見せることができる機能で、ユーザはデコレータtensorflow.python.framework.function.Defun() を利用することで、任意のファンクションを定義できます。

FunctionDef

FunctionDef は、ユーザが定義したファンクションを表現するデータ構造で、Protocol Buffersで定義されます。

ファンクションを一意に決定するために必要な名前(シグネイチャ)と、ファンクションの定義を示すためのノード構成情報を持ちます。reserved は過去に存在していたフィールドですが、TensorFlowのバージョンアップとともに不要になったため、削除されたようです。互換性を取るため、予約領域としてフィールドに残っているものと思われます。

tensorflow/core/framework/function.proto
message FunctionDef {
  OpDef signature = 1;
  map<string, AttrValue> attr = 5;
  reserved 2;
  repeated NodeDef node_def = 3;
  map<string, string> ret = 4;
};
フィールド 意味
signature ファンクションを一意に決定するシグネイチャ(OpDef
attr 属性情報
node_def ファンクションを構成するノード(NodeDef
ret 戻り値

演算を表現するデータ構造

TensorFlowの計算グラフは、NodeDef のフィールド op により、演算と1対1に結びついています。計算グラフは、演算を表現するデータ構造 OpDefKernelDef と照合することで、NodeDef で指定された演算が、TensorFlowでサポートされているかを確認します。

Operation.png

OpDef

OpDef は、TensorFlowにおける演算(Operation)仕様を表現する、Protocol Buffersで定義されたデータ構造です。

ユーザが定義した計算グラフの各ノード(NodeDef)で定義された演算が、OpDef で定義された演算仕様を満たしているかを確認します。OpDef で定義された演算仕様を満たさない演算を要求すると、プログラムが異常終了します。
入出力データを示すフィールド input_argoutput_arg に加えて、グラフ最適化処理などで利用されるフィールドが定義されています。

tensorflow/core/framework/op_def.proto
message OpDef {
  string name = 1;
  repeated ArgDef input_arg = 2;
  repeated ArgDef output_arg = 3;
  repeated AttrDef attr = 4;
  OpDeprecation deprecation = 8;
  string summary = 5;
  string description = 6;
  bool is_commutative = 18;
  bool is_aggregate = 16;
  bool is_stateful = 17;
  bool allows_uninitialized_input = 19;
};
フィールド 意味
name 演算名
input_arg 入力データ
output_arg 出力データ
attr 属性情報
deprecation Deprecatedであるデータ
summary 演算の説明
description 演算の説明(summary の長文版)
is_commutative 入力データの入力番号を交換可能な演算の場合は、true
aggregate 複数の入力データから1つの出力データを出力する演算の場合は、true
is_stateful ステートフルな演算の場合は、true
allows_uninitialized_input 未初期化の入力データを許す場合は true

KernelDef

KernelDef は、演算本体(OpKernel)の定義を表現する、Protocol Buffersで定義されたデータ構造です。

ユーザが定義した計算グラフの各ノード(NodeDef)の情報を KernelDef の情報と照合させ、マッチした KernelDef に対応する、演算本体であるクラス OpKernel を取得します。
NodeDef で定義した演算に一致する特定の OpKernel を取得する必要があるため、KernelDef にはデバイス情報や制約条件が保存されています。仮に、NodeDef に保持された情報を満たす KernelDef が存在しない場合は、プログラムが異常終了します。

tensorflow/core/framework/kernel_def.proto
message KernelDef {
  string op = 1;
  string device_type = 2;
  message AttrConstraint {
    string name = 1;
    AttrValue allowed_values = 2;
  }
  repeated AttrConstraint constraint = 3;
  repeated string host_memory_arg = 4;
  string label = 5;
};
フィールド 意味
op 演算名
device_type 演算実行デバイス
constraint 属性情報の制約条件
host_memory_arg ホストメモリを利用する入力データ/出力データ
label (不明)
25
16
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
25
16