
Reading Through TensorFlow [Session#run, Part 1]

Posted at 2016-05-08

Session#run

Session#run
  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """Runs the operations and evaluates the tensors in `fetches`.
    Args:
      fetches: A single graph element, or a list of graph elements
        (described above).
      feed_dict: A dictionary that maps graph elements to values
        (described above).
      options: A [`RunOptions`] protocol buffer
      run_metadata: A [`RunMetadata`] protocol buffer

    Returns:
      Either a single value if `fetches` is a single graph element, or
      a list of values if `fetches` is a list (described above).
    """
    run_metadata_ptr = tf_session.TF_NewBuffer()
    if options:
      options_ptr = tf_session.TF_NewBufferFromString(
          compat.as_bytes(options.SerializeToString()))
    else:
      options_ptr = None

    try:
      result = self._run(None, fetches, feed_dict, options_ptr,
                         run_metadata_ptr)
      if run_metadata:
        proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
        run_metadata.ParseFromString(compat.as_bytes(proto_data))
    finally:
      tf_session.TF_DeleteBuffer(run_metadata_ptr)
      if options:
        tf_session.TF_DeleteBuffer(options_ptr)
    return result

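To keep things simple, assume that fetches is the only argument passed.
In that case options_ptr is None, and run does little more than return the result of _run.
Let's go through _run a little at a time.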
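The try/finally structure of run can be sketched in plain Python. This is only an illustration and not TensorFlow code: `new_buffer`, `delete_buffer`, `inner_run` and `live_buffers` are hypothetical stand-ins for the `tf_session` buffer helpers and `_run`.

```python
# Hypothetical stand-ins for the C-API buffer helpers; not TensorFlow code.
live_buffers = set()


def new_buffer(name):
    """Pretend to allocate a native buffer and remember that it is live."""
    live_buffers.add(name)
    return name


def delete_buffer(name):
    """Pretend to free a native buffer."""
    live_buffers.discard(name)


def inner_run(fetches):
    """Stand-in for _run: fails when there is nothing to fetch."""
    if not fetches:
        raise ValueError("nothing to fetch")
    return [f.upper() for f in fetches]


def run(fetches):
    """Allocate a buffer, delegate to inner_run, and always free the buffer."""
    metadata_buf = new_buffer("run_metadata")
    try:
        return inner_run(fetches)
    finally:
        # Executed on both the normal and the exceptional path.
        delete_buffer(metadata_buf)


print(run(["a", "b"]))  # ['A', 'B']
```

The point of the sketch is that the buffer is released even when the delegated call raises, which is what the finally clause in run is there for.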

Session#_run

Session#_run
  def _run(self, handle, fetches, feed_dict, options, run_metadata):
    """Perform either run or partial_run, depending the exitence of `handle`."""

    # Check session.
    if self._closed:
      raise RuntimeError('Attempted to use a closed Session.')
    if self.graph.version == 0:
      raise RuntimeError('The Session graph is empty.  Add operations to the '
                         'graph before calling run().')

    # Validate and process fetches.
    processed_fetches = self._process_fetches(fetches)
    unique_fetches = processed_fetches[0]
    target_list = processed_fetches[1]
    fetch_info = processed_fetches[2]
    unique_handles = processed_fetches[3]

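Again, assume that fetches is the only argument.
First, _run checks that the Session has not been closed and that its graph is not empty (graph.version is not 0).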
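These two guard checks can be reproduced with a small stand-in class. `ToySession` and its `graph_version` attribute are hypothetical names used only for this sketch; the error messages are shortened copies of the ones in the source above.

```python
class ToySession:
    """Hypothetical stand-in for a session; graph_version mimics graph.version."""

    def __init__(self, graph_version=0):
        self._closed = False
        self.graph_version = graph_version

    def close(self):
        self._closed = True

    def run(self, fetches):
        # Guard 1: a closed session must not be used again.
        if self._closed:
            raise RuntimeError("Attempted to use a closed Session.")
        # Guard 2: version 0 means nothing was ever added to the graph.
        if self.graph_version == 0:
            raise RuntimeError("The Session graph is empty.")
        return fetches


def error_of(session):
    """Return the error message raised by run, or None on success."""
    try:
        session.run("x")
    except RuntimeError as e:
        return str(e)
    return None


empty = ToySession(graph_version=0)
closed = ToySession(graph_version=3)
closed.close()
print(error_of(empty))
print(error_of(closed))
```

Note that the closed check comes first, so a closed session with an empty graph reports the "closed" error.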

Next, let's look at self._process_fetches, which is called right after these checks.

Session#_process_fetches

BaseSession#_process_fetches
  def _process_fetches(self, fetches):
    """Validate and process fetches."""
    def _fetch_fn(fetch):
      for tensor_type, fetch_fn, _, _ in BaseSession._REGISTERED_EXPANSIONS:
        if isinstance(fetch, tensor_type):
          return fetch_fn(fetch)
      raise TypeError('Fetch argument %r has invalid type %r'
                      % (fetch, type(fetch)))

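It starts by defining a helper function for internal use.
If fetch is an instance of one of the registered types, the fetch_fn registered for that type is called.

The types registered in advance are the following.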

BaseSession#_REGISTERED_EXPANSIONS
  _REGISTERED_EXPANSIONS = [
      # SparseTensors are fetched as SparseTensorValues. They can be fed
      # SparseTensorValues or normal tuples.
      (ops.SparseTensor,
       lambda fetch: (
           [fetch.indices, fetch.values, fetch.shape],
           lambda fetched_vals: ops.SparseTensorValue(*fetched_vals)),
       lambda feed, feed_val: list(zip(
           [feed.indices, feed.values, feed.shape], feed_val)),
       lambda feed: [feed.indices, feed.values, feed.shape]),
      # IndexedSlices are fetched as IndexedSlicesValues. They can be fed
      # IndexedSlicesValues or normal tuples.
      (ops.IndexedSlices,
       lambda fetch: (
           [fetch.values, fetch.indices] if fetch.dense_shape is None
           else [fetch.values, fetch.indices, fetch.dense_shape],
           _get_indexed_slices_value_from_fetches),
       _get_feeds_for_indexed_slices,
       lambda feed: [feed.values, feed.indices] if feed.dense_shape is None
                    else [feed.values, feed.indices, feed.dense_shape]),
      # The default catches all types and performs no expansions.
      (object,
       lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]),
       lambda feed, feed_val: [(feed, feed_val)],
       lambda feed: [feed])]

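Each entry consists of a type followed by three functions (mostly lambdas): one that expands a fetch and two that handle feeds. _fetch_fn uses only the first function and ignores the other two.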
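The lookup can be tried out without TensorFlow. The sketch below is a simplified assumption-laden model: `Pair` is a hypothetical composite type standing in for SparseTensor, and each registry entry keeps only the type and the fetch function, dropping the two feed functions.

```python
from collections import namedtuple

# Hypothetical composite type standing in for SparseTensor.
Pair = namedtuple("Pair", ["left", "right"])

# Each entry: (type, expansion function). The first match wins, so the
# catch-all `object` entry has to come last.
REGISTERED_EXPANSIONS = [
    (Pair, lambda fetch: ([fetch.left, fetch.right],
                          lambda vals: Pair(*vals))),
    (object, lambda fetch: ([fetch], lambda vals: vals[0])),
]


def expand(fetch):
    """Return (subfetches, contraction_fn) for the first matching type."""
    for fetch_type, fetch_fn in REGISTERED_EXPANSIONS:
        if isinstance(fetch, fetch_type):
            return fetch_fn(fetch)
    raise TypeError("Fetch argument %r has invalid type %r"
                    % (fetch, type(fetch)))


subfetches, contract = expand(Pair("a", "b"))
print(subfetches)        # ['a', 'b']
print(contract([1, 2]))  # Pair(left=1, right=2)
```

Because every Python value is an instance of `object`, the final entry always matches, which is why it must be listed after the more specific types.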

BaseSession#_process_fetches
    # Validate and process fetches.
    is_list_fetch = isinstance(fetches, (list, tuple))
    if not is_list_fetch:
      fetches = [fetches]

    unique_fetch_targets = set()
    unique_fetch_handles = {}
    target_list = []

    fetch_info = []

After defining the helper, it checks whether fetches is a list or a tuple.
If it is neither, the single fetch is wrapped in a one-element list.
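A minimal sketch of this normalization step, using a hypothetical helper name `normalize_fetches`. It also shows why the value is wrapped with `[fetches]` rather than converted with `list(...)`: a string fetch would be split into characters.

```python
def normalize_fetches(fetches):
    """Return (fetch_list, was_list), wrapping a lone fetch in a list."""
    is_list_fetch = isinstance(fetches, (list, tuple))
    if not is_list_fetch:
        fetches = [fetches]
    return fetches, is_list_fetch


print(normalize_fetches("loss"))         # (['loss'], False)
print(normalize_fetches(("a", "b")))     # (('a', 'b'), True)
print(list("loss"))                      # ['l', 'o', 's', 's']
```

A tuple is left as a tuple; only the non-sequence case is changed.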

Once the containers are prepared, it starts processing fetches.

BaseSession#_process_fetches
    for fetch in fetches:
      subfetches, fetch_contraction_fn = _fetch_fn(fetch)
      subfetch_names = []

This calls the helper defined earlier. When fetch is a SparseTensor, for example, the result is:
subfetches = [fetch.indices, fetch.values, fetch.shape]
fetch_contraction_fn = lambda fetched_vals: ops.SparseTensorValue(*fetched_vals)

BaseSession#_process_fetches
      for subfetch in subfetches:
        try:
          fetch_t = self.graph.as_graph_element(subfetch, allow_tensor=True,
                                                allow_operation=True)
          fetch_name = compat.as_bytes(fetch_t.name)
          if isinstance(fetch_t, ops.Operation):
            target_list.append(fetch_name)
          else:
            subfetch_names.append(fetch_name)
          # Remember the fetch if it is for a tensor handle.
          if (isinstance(fetch_t, ops.Tensor) and
              fetch_t.op.type == 'GetSessionHandle'):
            unique_fetch_handles[fetch_name] = fetch_t.op.inputs[0].dtype
        except TypeError as e:
          raise TypeError('Fetch argument %r of %r has invalid type %r, '
                          'must be a string or Tensor. (%s)'
                          % (subfetch, fetch, type(subfetch), str(e)))
        except ValueError as e:
          raise ValueError('Fetch argument %r of %r cannot be interpreted as a '
                           'Tensor. (%s)' % (subfetch, fetch, str(e)))
        except KeyError as e:
          raise ValueError('Fetch argument %r of %r cannot be interpreted as a '
                           'Tensor. (%s)' % (subfetch, fetch, str(e)))

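Each subfetch is sorted into a different container depending on its type: the name of an Operation goes into target_list, and any other name (that of a Tensor) goes into subfetch_names. If the Tensor is the output of a GetSessionHandle op, its name and the dtype of the op's first input are also recorded in unique_fetch_handles.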
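The routing by type can be sketched as follows. `Operation` and `Tensor` here are hypothetical stand-in classes, not the real `ops.Operation` and `ops.Tensor`, and `str.encode` is used where the original calls `compat.as_bytes`.

```python
class Operation:
    """Hypothetical stand-in for ops.Operation."""

    def __init__(self, name):
        self.name = name


class Tensor:
    """Hypothetical stand-in for ops.Tensor."""

    def __init__(self, name):
        self.name = name


def split_by_kind(elements):
    """Route operation names to targets and all other names to fetch names."""
    target_list, subfetch_names = [], []
    for element in elements:
        name = element.name.encode("utf-8")
        if isinstance(element, Operation):
            target_list.append(name)
        else:
            subfetch_names.append(name)
    return target_list, subfetch_names


targets, names = split_by_kind([Operation("init"), Tensor("loss:0")])
print(targets)  # [b'init']
print(names)    # [b'loss:0']
```

A plausible reason for the split is that an Operation is only executed, while a Tensor produces a value to return.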

BaseSession#_process_fetches
      unique_fetch_targets.update(subfetch_names)
      fetch_info.append((subfetch_names, fetch_contraction_fn))

The elements of subfetch_names are added one by one to unique_fetch_targets.
The tuple (subfetch_names, fetch_contraction_fn) is appended to fetch_info.
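The difference between the two containers shows up when two fetches share a subfetch name. In this sketch the names and the `"contract_a"` / `"contract_b"` placeholders are made up for illustration; in the real code the second tuple element is a function.

```python
unique_fetch_targets = set()
fetch_info = []

# Two fetches that share the subfetch name b"x:0".
batches = [([b"x:0", b"y:0"], "contract_a"), ([b"x:0"], "contract_b")]
for subfetch_names, contraction_fn in batches:
    # set.update adds each name individually; duplicates collapse.
    unique_fetch_targets.update(subfetch_names)
    # list.append keeps exactly one tuple per fetch, in order.
    fetch_info.append((subfetch_names, contraction_fn))

print(sorted(unique_fetch_targets))  # [b'x:0', b'y:0']
print(len(fetch_info))               # 2
```

So unique_fetch_targets holds each tensor name once, while fetch_info still remembers which names belong to which fetch.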

BaseSession#_process_fetches
    unique_fetch_targets = list(unique_fetch_targets)
    return unique_fetch_targets, target_list, fetch_info, unique_fetch_handles

Because unique_fetch_targets is a set, it is converted to a list.
Finally, the four containers are returned as a tuple, which _run unpacks by index.

Graph#as_graph_element

as_graph_element returns the value that ends up in fetch_t, so let's look at it.

Graph#as_graph_element
  def as_graph_element(self, obj, allow_tensor=True, allow_operation=True):
    """Returns the object referred to by `obj`, as an `Operation` or `Tensor`.

    Args:
      obj: A `Tensor`, an `Operation`, or the name of a tensor or operation.
        Can also be any object with an `_as_graph_element()` method that returns
        a value of one of these types.
      allow_tensor: If true, `obj` may refer to a `Tensor`.
      allow_operation: If true, `obj` may refer to an `Operation`.
    Returns:
      The `Tensor` or `Operation` in the Graph corresponding to `obj`.

From the docstring, we can tell that it converts the object passed as an argument into a Tensor or an Operation.

Graph#as_graph_element
    if allow_tensor and allow_operation:
      types_str = "Tensor or Operation"
    elif allow_tensor:
      types_str = "Tensor"
    elif allow_operation:
      types_str = "Operation"
    else:
      raise ValueError("allow_tensor and allow_operation can't both be False.")

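The type string depends on the two boolean arguments and is used later in the TypeError message. If both are False, a ValueError is raised.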

Graph#as_graph_element
    temp_obj = _as_graph_element(obj)
    if temp_obj is not None:
      obj = temp_obj

Let's look at _as_graph_element.

ops._as_graph_element
def _as_graph_element(obj):
  """Convert `obj` to a graph element if possible, otherwise return `None`.
  Args:
    obj: Object to convert.
  Returns:
    The result of `obj._as_graph_element()` if that method is available;
        otherwise `None`.
  """
  conv_fn = getattr(obj, "_as_graph_element", None)
  if conv_fn and callable(conv_fn):
    return conv_fn()
  return None

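If the object implements _as_graph_element, that method is called; otherwise None is returned and the object is used unchanged.
Incidentally, tf.Variable, which appears in many tutorials, also implements _as_graph_element.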
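This optional-hook pattern is easy to try on its own. In the sketch below, `ToyVariable`, `as_graph_element_or_none` and `resolve` are hypothetical names; only the `getattr` / `callable` logic mirrors the source above.

```python
def as_graph_element_or_none(obj):
    """Return obj._as_graph_element() if obj offers that hook, else None."""
    conv_fn = getattr(obj, "_as_graph_element", None)
    if conv_fn and callable(conv_fn):
        return conv_fn()
    return None


class ToyVariable:
    """Hypothetical wrapper that exposes an underlying element."""

    def __init__(self, element):
        self._element = element

    def _as_graph_element(self):
        return self._element


def resolve(obj):
    """Replace obj with its converted form only when a conversion exists."""
    converted = as_graph_element_or_none(obj)
    return converted if converted is not None else obj


print(resolve(ToyVariable("v:0")))  # v:0
print(resolve("plain"))             # plain
```

Using `getattr` with a default means any class can opt in by defining the method, without inheriting from a common base class.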

Graph#as_graph_element
    # If obj appears to be a name...
    if isinstance(obj, compat.bytes_or_text_types): ...
    elif isinstance(obj, Tensor) and allow_tensor:
      # Actually obj is just the object it's referring to.
      if obj.graph is not self:
        raise ValueError("Tensor %s is not an element of this graph." % obj)
      return obj
    elif isinstance(obj, Operation) and allow_operation:
      # Actually obj is just the object it's referring to.
      if obj.graph is not self:
        raise ValueError("Operation %s is not an element of this graph." % obj)
      return obj
    else:
      # We give up!
      raise TypeError("Can not convert a %s into a %s."
                      % (type(obj).__name__, types_str))

If obj.graph is not this Graph (self), a ValueError is raised.
If there is no error, obj is returned as-is.
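The ownership check can be illustrated with two toy graphs. `ToyGraph` and `ToyTensor` are hypothetical classes written for this sketch; the name-lookup and Operation branches of the real method are left out.

```python
class ToyTensor:
    """Hypothetical tensor that remembers the graph it belongs to."""

    def __init__(self, name, graph):
        self.name = name
        self.graph = graph


class ToyGraph:
    """Hypothetical graph that only accepts elements it owns."""

    def element(self, name):
        return ToyTensor(name, self)

    def as_graph_element(self, obj):
        if isinstance(obj, ToyTensor):
            # Identity check: the tensor must belong to this exact graph.
            if obj.graph is not self:
                raise ValueError(
                    "Tensor %s is not an element of this graph." % obj.name)
            return obj
        raise TypeError(
            "Can not convert a %s into a Tensor." % type(obj).__name__)


g1, g2 = ToyGraph(), ToyGraph()
t = g1.element("x:0")
print(g1.as_graph_element(t) is t)  # True
try:
    g2.as_graph_element(t)
except ValueError as e:
    print(e)  # Tensor x:0 is not an element of this graph.
```

The comparison uses `is`, so two graphs with identical contents are still treated as different owners.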

This gives us a rough picture of what happens up to the point where the fetches have been processed.
