Session#run
def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
  """Runs the operations and evaluates the tensors in `fetches`.

  Args:
    fetches: A single graph element, or a list of graph elements
      (described above).
    feed_dict: A dictionary that maps graph elements to values
      (described above).
    options: A [`RunOptions`] protocol buffer
    run_metadata: A [`RunMetadata`] protocol buffer

  Returns:
    Either a single value if `fetches` is a single graph element, or
    a list of values if `fetches` is a list (described above).
  """
  # Native buffer handed to `_run`; it is read back below and parsed into
  # the caller's `run_metadata` proto.
  run_metadata_ptr = tf_session.TF_NewBuffer()
  # RunOptions cross the Python/C boundary as a serialized byte buffer.
  if options:
    options_ptr = tf_session.TF_NewBufferFromString(
        compat.as_bytes(options.SerializeToString()))
  else:
    options_ptr = None
  try:
    # A `None` handle selects a full run rather than a partial_run.
    result = self._run(None, fetches, feed_dict, options_ptr,
                       run_metadata_ptr)
    if run_metadata:
      # Copy the collected metadata back into the caller-supplied proto.
      proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
      run_metadata.ParseFromString(compat.as_bytes(proto_data))
  finally:
    # Release the native buffers even when `_run` raises.
    tf_session.TF_DeleteBuffer(run_metadata_ptr)
    if options:
      tf_session.TF_DeleteBuffer(options_ptr)
  return result
簡単に考えるため、引数は fetches だけと仮定します。
`_run` で実行した結果を返しているだけのようです。
`_run` を少しずつ見ていきます。
Session#_run
def _run(self, handle, fetches, feed_dict, options, run_metadata):
"""Perform either run or partial_run, depending the exitence of `handle`."""
# Check session.
if self._closed:
raise RuntimeError('Attempted to use a closed Session.')
if self.graph.version == 0:
raise RuntimeError('The Session graph is empty. Add operations to the '
'graph before calling run().')
# Validate and process fetches.
processed_fetches = self._process_fetches(fetches)
unique_fetches = processed_fetches[0]
target_list = processed_fetches[1]
fetch_info = processed_fetches[2]
unique_handles = processed_fetches[3]
引き続き、引数は fetches のみと仮定します。
最初に、Session がクローズ済みでないか（`self._closed`）、グラフが空でないか（`self.graph.version == 0`）を確認しています。
次の処理である `self._process_fetches` を見てみます。
Session#_process_fetches
def _process_fetches(self, fetches):
"""Validate and process fetches."""
def _fetch_fn(fetch):
for tensor_type, fetch_fn, _, _ in BaseSession._REGISTERED_EXPANSIONS:
if isinstance(fetch, tensor_type):
return fetch_fn(fetch)
raise TypeError('Fetch argument %r has invalid type %r'
% (fetch, type(fetch)))
最初に、内部で使う関数 `_fetch_fn` を定義しています。
fetch が `BaseSession._REGISTERED_EXPANSIONS` に登録されている型のインスタンスであれば、対応する `fetch_fn` が実行されます。どの型にも当てはまらなければ TypeError になります。
登録されている型はこちらです。
_REGISTERED_EXPANSIONS = [
    # Each entry is a 4-tuple:
    #   (type, fetch_fn, feed_fn, feed_expansion_fn)
    # - fetch_fn(fetch) -> (list of sub-fetches, function that rebuilds one
    #   value from the fetched sub-values).
    # - feed_fn(feed, feed_val) -> list of (tensor, value) pairs; the
    #   SparseTensor and default entries build these pairs directly.
    # - feed_expansion_fn(feed) -> list of component tensors of `feed`.
    # Entries are tried in order with isinstance(), so the catch-all
    # `object` entry must stay last.
    #
    # SparseTensors are fetched as SparseTensorValues. They can be fed
    # SparseTensorValues or normal tuples.
    (ops.SparseTensor,
     lambda fetch: (
         [fetch.indices, fetch.values, fetch.shape],
         lambda fetched_vals: ops.SparseTensorValue(*fetched_vals)),
     lambda feed, feed_val: list(zip(
         [feed.indices, feed.values, feed.shape], feed_val)),
     lambda feed: [feed.indices, feed.values, feed.shape]),
    # IndexedSlices are fetched as IndexedSlicesValues. They can be fed
    # IndexedSlicesValues or normal tuples.
    (ops.IndexedSlices,
     lambda fetch: (
         [fetch.values, fetch.indices] if fetch.dense_shape is None
         else [fetch.values, fetch.indices, fetch.dense_shape],
         _get_indexed_slices_value_from_fetches),
     _get_feeds_for_indexed_slices,
     lambda feed: [feed.values, feed.indices] if feed.dense_shape is None
     else [feed.values, feed.indices, feed.dense_shape]),
    # The default catches all types and performs no expansions.
    (object,
     lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]),
     lambda feed, feed_val: [(feed, feed_val)],
     lambda feed: [feed])]
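各要素は「型」と3つの関数の組です。関数の多くは lambda 式ですが、IndexedSlices の一部は名前付き関数です。2番目の関数が fetch 用で、(サブ fetch のリスト, 取得した値を1つにまとめる関数) を返します。残りの2つは feed 用です。最後の `object` はあらゆる値に一致するため、他の型に当てはまらない fetch はすべてここで処理されます。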
# Validate and process fetches.
is_list_fetch = isinstance(fetches, (list, tuple))
if not is_list_fetch:
fetches = [fetches]
unique_fetch_targets = set()
unique_fetch_handles = {}
target_list = []
fetch_info = []
関数を定義した後は、fetches が list または tuple かどうかを確認しています。
list でも tuple でもない場合は、要素1つの list に包んでいます。
結果を入れる入れ物を用意した後、各 fetch の処理に入ります。
for fetch in fetches:
subfetches, fetch_contraction_fn = _fetch_fn(fetch)
subfetch_names = []
先ほど定義した `_fetch_fn` を実行します。
例えば fetch が SparseTensor の場合は、
`subfetches = [fetch.indices, fetch.values, fetch.shape]`、
`fetch_contraction_fn = lambda fetched_vals: ops.SparseTensorValue(*fetched_vals)`
となります。
for subfetch in subfetches:
try:
fetch_t = self.graph.as_graph_element(subfetch, allow_tensor=True,
allow_operation=True)
fetch_name = compat.as_bytes(fetch_t.name)
if isinstance(fetch_t, ops.Operation):
target_list.append(fetch_name)
else:
subfetch_names.append(fetch_name)
# Remember the fetch if it is for a tensor handle.
if (isinstance(fetch_t, ops.Tensor) and
fetch_t.op.type == 'GetSessionHandle'):
unique_fetch_handles[fetch_name] = fetch_t.op.inputs[0].dtype
except TypeError as e:
raise TypeError('Fetch argument %r of %r has invalid type %r, '
'must be a string or Tensor. (%s)'
% (subfetch, fetch, type(subfetch), str(e)))
except ValueError as e:
raise ValueError('Fetch argument %r of %r cannot be interpreted as a '
'Tensor. (%s)' % (subfetch, fetch, str(e)))
except KeyError as e:
raise ValueError('Fetch argument %r of %r cannot be interpreted as a '
'Tensor. (%s)' % (subfetch, fetch, str(e)))
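各 subfetch を `as_graph_element` でグラフ要素 `fetch_t` に変換し、その型によって入れる箱を分けています。Operation なら名前を `target_list` に、それ以外（Tensor）なら `subfetch_names` に追加します。さらに、`GetSessionHandle` 演算の出力である Tensor は `unique_fetch_handles` にも記録します。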
unique_fetch_targets.update(subfetch_names)
fetch_info.append((subfetch_names, fetch_contraction_fn))
`unique_fetch_targets` に `subfetch_names` の各要素を追加しています（set なので重複は除かれます）。
`fetch_info` には `(subfetch_names, fetch_contraction_fn)` の tuple を追加しています。
unique_fetch_targets = list(unique_fetch_targets)
return unique_fetch_targets, target_list, fetch_info, unique_fetch_handles
`unique_fetch_targets` は set なので list に変換しています。
最終的に各入れ物を返しています。この4つが `_run` 側で `processed_fetches[0]`〜`processed_fetches[3]` として受け取られます。
Graph#as_graph_element
`fetch_t` に入るデータを返す関数なので、見てみます。
def as_graph_element(self, obj, allow_tensor=True, allow_operation=True):
"""Returns the object referred to by `obj`, as an `Operation` or `Tensor`.
Args:
obj: A `Tensor`, an `Operation`, or the name of a tensor or operation.
Can also be any object with an `_as_graph_element()` method that returns
a value of one of these types.
allow_tensor: If true, `obj` may refer to a `Tensor`.
allow_operation: If true, `obj` may refer to an `Operation`.
Returns:
The `Tensor` or `Operation` in the Graph corresponding to `obj`.
docstring から、引数 `obj` を `Tensor` または `Operation` に変換するものだとわかります。
if allow_tensor and allow_operation:
types_str = "Tensor or Operation"
elif allow_tensor:
types_str = "Tensor"
elif allow_operation:
types_str = "Operation"
else:
raise ValueError("allow_tensor and allow_operation can't both be False.")
2つの boolean 引数によって、エラーメッセージで使う型の文字列 `types_str` が変わります。両方とも False の場合は ValueError になります。
temp_obj = _as_graph_element(obj)
if temp_obj is not None:
obj = temp_obj
`_as_graph_element` を見てみます。
def _as_graph_element(obj):
"""Convert `obj` to a graph element if possible, otherwise return `None`.
Args:
obj: Object to convert.
Returns:
The result of `obj._as_graph_element()` if that method is available;
otherwise `None`.
"""
conv_fn = getattr(obj, "_as_graph_element", None)
if conv_fn and callable(conv_fn):
return conv_fn()
return None
obj に呼び出し可能な `_as_graph_element` 属性があれば、それを呼び出した結果を返し、なければ `None` を返します。
ちなみに、チュートリアルでよく見る `tf.Variable()` にも `_as_graph_element` は実装されていました。
# If obj appears to be a name...
if isinstance(obj, compat.bytes_or_text_types): ...
elif isinstance(obj, Tensor) and allow_tensor:
# Actually obj is just the object it's referring to.
if obj.graph is not self:
raise ValueError("Tensor %s is not an element of this graph." % obj)
return obj
elif isinstance(obj, Operation) and allow_operation:
# Actually obj is just the object it's referring to.
if obj.graph is not self:
raise ValueError("Operation %s is not an element of this graph." % obj)
return obj
else:
# We give up!
raise TypeError("Can not convert a %s into a %s."
% (type(obj).__name__, types_str))
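obj が Tensor（または Operation）であっても、`obj.graph` がこのグラフ自身（`self`）でなければ、つまり別のグラフに属していれば ValueError となります。
エラーがなければ obj をそのまま返しています。許可されていない型の場合は TypeError になります。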
fetchを処理するところまでは大まかですが、何をしているかわかりました。