TensorFlowには様々なデータ構造が存在します。その中でも計算グラフの定義やチェックポイントの保存に使用されるProtocol Buffers形式で定義されたデータ構造は、TensorFlowの計算グラフの処理で使用されるものであり、最も重要なデータ構造です。
TensorFlowの内部構造解析に入る前に、Protocol Buffers形式で定義された特に重要なデータ構造を紹介し、TensorFlowで扱っているデータ構造をまとめます。
- TensorFlow v1.9.0
- コミットID: a44996a84b24c43cca40c685a009fd59275755ab
前知識
Protocol Buffers
Protocol Buffers1は、Googleにより開発されたシリアライズフォーマットです。C++やPython、Go、Javascriptなど、様々な言語から扱うことができます。TensorFlowはProtocol Buffersフォーマットを、チェックポイントファイルの保存や言語間でのデータ転送、計算グラフ構築処理などの様々な場所で利用します。
Protocol Buffersで定義されたファイルは、protocと呼ばれるProtocol Buffersのコンパイラを使って、各言語のソースコードを出力します。例えば、ファイル名が attr_value.proto
であるProtocol Buffers形式の定義ファイルを、protocを使ってPython向けにコードを生成した時、Pythonのソースコード attr_value_pb2.py
が生成されます。生成された attr_value_pb2
モジュールを使って、Protocol Buffersで定義されたデータ構造にアクセスすることができます。
グラフを表現するデータ構造
TensorFlow内でProtocol Buffersで定義されるグラフを表現するデータ構造は、GraphDef
です。GraphDef
は、グラフのノードを表現するデータ構造 NodeDef
から構成されます。
GraphDef
GraphDef
は、TensorFlow内の計算グラフを表現するデータ構造で、Protocol Buffersで定義されます。
複数のノードから構成される計算グラフは、計算グラフに含まれるノードを表現するデータ構造 NodeDef
のリストを保持しています。また、GraphDef
は、TensorFlowのSavedModel 2 のグラフ情報を保存する時に使用する情報であるため、互換性の確認に必要なバージョン情報が含まれています。
message GraphDef {
repeated NodeDef node = 1;
VersionDef versions = 4;
int32 version = 3 [deprecated = true];
FunctionDefLibrary library = 2;
};
フィールド | 意味 |
---|---|
node |
グラフのノードを表現する NodeDef のリスト |
versions |
バージョン |
version |
バージョン(Deprecated) |
library |
ユーザ定義の関数 |
NodeDef
NodeDef
は、グラフのノードを表現するデータ構造で、Protocol Buffersで定義されます。
ノード名やノードに対応する演算(Operation)、入力ノードなど、TensorFlowの計算グラフを構築するために必要な情報が含まれています。
message NodeDef {
string name = 1;
string op = 2;
repeated string input = 3;
string device = 4;
map<string, AttrValue> attr = 5;
};
フィールド | 意味 |
---|---|
name |
ノードを一意に特定するための名前 |
op |
ノードに紐づいたOperation名 |
input |
入力ノードのリスト |
device |
ユーザからリクエストされた、演算を実行するデバイス |
attr |
属性情報 |
AttrValue
AttrValue
は、値を保持するデータ構造で、Protocol Buffersで定義されます。
oneof
修飾子により、ある1つのデータ型のデータを保持することができます。
message AttrValue {
message ListValue {
repeated bytes s = 2;
repeated int64 i = 3 [packed = true];
repeated float f = 4 [packed = true];
repeated bool b = 5 [packed = true];
repeated DataType type = 6 [packed = true];
repeated TensorShapeProto shape = 7;
repeated TensorProto tensor = 8;
repeated NameAttrList func = 9;
}
oneof value {
bytes s = 2;
int64 i = 3;
float f = 4;
bool b = 5;
DataType type = 6;
TensorShapeProto shape = 7;
TensorProto tensor = 8;
ListValue list = 1;
NameAttrList func = 10;
string placeholder = 9;
}
}
フィールド | 型 |
---|---|
s |
文字列 |
i |
整数値 |
f |
浮動小数点数 |
b |
ブール値 |
type |
DataType (TensorFlow内部で使用されるデータ型) |
shape |
TensorShapeProto (テンソルの形) |
tensor |
TensorProto (テンソルの実体) |
list |
リスト(ListValue ) |
func |
ユーザ定義関数 |
placeholder |
ユーザ定義関数内で利用されるPlaceholder |
テンソルを表現するデータ構造
TensorFlow内でProtocol Buffersで定義されるテンソルを表現するデータ構造は、TensorProto
です。TensorProto
は、テンソルの形を表現するデータ構造 TensorShapeProto
とテンソルの各要素のデータから構成されます。
TensorProto
TensorProto
は、テンソルを表現するためのデータ構造で、Protocol Buffersで定義されます。
フィールド dtype
によって TensorProto
に保存されるデータ型が決まり、dtype
に対応する <type>_val
に、テンソルの各要素の値が保持されます。(<type>
は各データ型)また、<type>_val
をシリアライズしたデータが tensor_content
に保持されています。
message TensorProto {
DataType dtype = 1;
TensorShapeProto tensor_shape = 2;
int32 version_number = 3;
bytes tensor_content = 4;
repeated int32 half_val = 13 [packed = true];
repeated float float_val = 5 [packed = true];
repeated double double_val = 6 [packed = true];
repeated int32 int_val = 7 [packed = true];
repeated bytes string_val = 8;
repeated float scomplex_val = 9 [packed = true];
repeated int64 int64_val = 10 [packed = true];
repeated bool bool_val = 11 [packed = true];
repeated double dcomplex_val = 12 [packed = true];
repeated ResourceHandleProto resource_handle_val = 14;
repeated VariantTensorDataProto variant_val = 15;
repeated uint32 uint32_val = 16 [packed = true];
repeated uint64 uint64_val = 17 [packed = true];
};
フィールド | 意味 |
---|---|
dtype |
データ型(DataType ) |
tensor_shape |
テンソルの形(TensorShapeProto ) |
version_number |
バージョン |
tensor_content |
シリアライズ形式のデータ |
TensorShapeProto
TensorShapeProto
は、TensorProto
で表現されるテンソルの形を表現する、Protocol Buffersで定義されたデータ構造です。
フィールド dim
がテンソルの形を表し、テンソルの各軸の要素数が保存されています。
message TensorShapeProto {
message Dim {
int64 size = 1;
string name = 2;
};
repeated Dim dim = 2;
bool unknown_rank;
};
ファンクションを表現するデータ構造
TensorFlowには、ファンクションと呼ばれる機能があります。ファンクションは、複数の演算を1つのノードとして見せることができる機能で、ユーザはデコレータtensorflow.python.framework.function.Defun()
を利用することで、任意のファンクションを定義できます。
FunctionDef
FunctionDef
は、ユーザが定義したファンクションを表現するデータ構造で、Protocol Buffersで定義されます。
ファンクションを一意に決定するために必要な名前(シグネイチャ)と、ファンクションの定義を示すためのノード構成情報を持ちます。reserved
は過去に存在していたフィールドですが、TensorFlowのバージョンアップとともに不要になったため、削除されたようです。互換性を取るため、予約領域としてフィールドに残っているものと思われます。
message FunctionDef {
OpDef signature = 1;
map<string, AttrValue> attr = 5;
reserved 2;
repeated NodeDef node_def = 3;
map<string, string> ret = 4;
};
フィールド | 意味 |
---|---|
signature |
ファンクションを一意に決定するシグネイチャ(OpDef ) |
attr |
属性情報 |
node_def |
ファンクションを構成するノード(NodeDef ) |
ret |
戻り値 |
演算を表現するデータ構造
TensorFlowの計算グラフは、NodeDef
のフィールド op
により、演算と1対1に結びついています。計算グラフは、演算を表現するデータ構造 OpDef
と KernelDef
と照合することで、NodeDef
で指定された演算が、TensorFlowでサポートされているかを確認します。
OpDef
OpDef
は、TensorFlowにおける演算(Operation)仕様を表現する、Protocol Buffersで定義されたデータ構造です。
ユーザが定義した計算グラフの各ノード(NodeDef
)で定義された演算が、OpDef
で定義された演算仕様を満たしているかを確認します。OpDef
で定義された演算仕様を満たさない演算を要求すると、プログラムが異常終了します。
入出力データを示すフィールド input_arg
や output_arg
に加えて、グラフ最適化処理などで利用されるフィールドが定義されています。
message OpDef {
string name = 1;
repeated ArgDef input_arg = 2;
repeated ArgDef output_arg = 3;
repeated AttrDef attr = 4;
OpDeprecation deprecation = 8;
string summary = 5;
string description = 6;
bool is_commutative = 18;
bool is_aggregate = 16;
bool is_stateful = 17;
bool allows_uninitialized_input = 19;
};
フィールド | 意味 |
---|---|
name |
演算名 |
input_arg |
入力データ |
output_arg |
出力データ |
attr |
属性情報 |
deprecation |
Deprecatedであるデータ |
summary |
演算の説明 |
description |
演算の説明(summary の長文版) |
is_commutative |
入力データの入力番号を交換可能な演算の場合は、true
|
aggregate |
複数の入力データから1つの出力データを出力する演算の場合は、true
|
is_stateful |
ステートフルな演算の場合は、true
|
allows_uninitialized_input |
未初期化の入力データを許す場合は true
|
KernelDef
KernelDef
は、演算本体(OpKernel
)の定義を表現する、Protocol Buffersで定義されたデータ構造です。
ユーザが定義した計算グラフの各ノード(NodeDef
)の情報を KernelDef
の情報と照合させ、マッチした KernelDef
に対応する、演算本体であるクラス OpKernel
を取得します。
NodeDef
で定義した演算に一致する特定の OpKernel
を取得する必要があるため、KernelDef
にはデバイス情報や制約条件が保存されています。仮に、NodeDef
に保持された情報を満たす KernelDef
が存在しない場合は、プログラムが異常終了します。
message KernelDef {
string op = 1;
string device_type = 2;
message AttrConstraint {
string name = 1;
AttrValue allowed_values = 2;
}
repeated AttrConstraint constraint = 3;
repeated string host_memory_arg = 4;
string label = 5;
};
フィールド | 意味 |
---|---|
op |
演算名 |
device_type |
演算実行デバイス |
constraint |
属性情報の制約条件 |
host_memory_arg |
ホストメモリを利用する入力データ/出力データ |
label |
(不明) |