LoginSignup
2
2

More than 5 years have passed since last update.

TensorFlowを読み解いていく~Session作成編~

Posted at

Library探索

機械学習をすぐに習得するのは難しいので、TensorFlowをライブラリとして見ていきたいと思います。

全般

基本的にtensorflow/python配下のclass,methodは全てimportされています。
一部、C++のコードがBazelというツールでPythonライブラリとしてビルドされているようです。

Session

計算を実行するクラス。
https://github.com/tensorflow/tensorflow/blob/r0.8/tensorflow/python/client/session.py

Session
class Session(BaseSession):
  def __init__(self, target='', graph=None, config=None):
    super(Session, self).__init__(target, graph, config=config)
    self._context_managers = [self.graph.as_default(), self.as_default()]

baseクラスの__init__`を見ます。

BaseSession
class BaseSession(SessionInterface):
  def __init__(self, target='', graph=None, config=None):
    if graph is None:
      self._graph = ops.get_default_graph()
    else:
      self._graph = graph
    ...,
    opts = tf_session.TF_NewSessionOptions(target=target, config=config)
    try:
      status = tf_session.TF_NewStatus()
      try:
        self._session = tf_session.TF_NewSession(opts, status)
        if tf_session.TF_GetCode(status) != 0:
          raise RuntimeError(compat.as_text(tf_session.TF_Message(status)))
      finally:
        tf_session.TF_DeleteStatus(status)
    finally:
      tf_session.TF_DeleteSessionOptions(opts)

graphとsessionを設定していることがわかる。

import
from tensorflow.python.framework import ops
from tensorflow.python import pywrap_tensorflow as tf_session

opsとtf_sessionを定義しているファイルを見てみる。

ops.py

ops.py
class _DefaultStack(threading.local):
  def __init__(self):
    super(_DefaultStack, self).__init__()
    self.stack = []

  def get_default(self):
    return self.stack[-1] if len(self.stack) >= 1 else None

class _DefaultGraphStack(_DefaultStack):
  def __init__(self):
    super(_DefaultGraphStack, self).__init__()
    self._global_default_graph = None

  def get_default(self):
    ret = super(_DefaultGraphStack, self).get_default()
    if ret is None:
      ret = self._GetGlobalDefaultGraph()
    return ret

  def _GetGlobalDefaultGraph(self):
    if self._global_default_graph is None:
      self._global_default_graph = Graph()
    return self._global_default_graph

_default_graph_stack = _DefaultGraphStack()

def get_default_graph():
  return _default_graph_stack.get_default()

graphにはGraph()が入りそうです。

pywrap_tensorflowは少しややこしくて、Bazelというビルドツールでビルドされるファイルのようです。

BUILD

tf_py_wrap_cc(
    name = "pywrap_tensorflow",
    srcs = ["tensorflow.i"],
    ...,
)

tensorflow.i

%include "tensorflow/python/client/tf_session.i"

tf_session.i
これはどうやらswigというC/C++の機能を他言語に組み込む定義ファイルのようです。
http://www.swig.org/Doc3.0/SWIGDocumentation.html#SWIG_nn2

// Include the functions from tensor_c_api.h, except TF_Run.
%unignore TF_NewStatus;
%unignore TF_DeleteStatus;
%unignore TF_GetCode;
%unignore TF_Message;
%rename("_TF_SetTarget") TF_SetTarget;
%rename("_TF_NewSessionOptions") TF_NewSessionOptions;
%unignore TF_DeleteSessionOptions;
%include "tensorflow/core/public/tensor_c_api.h"

%insert("python") %{
  def TF_NewSessionOptions(target=None, config=None):
    opts = _TF_NewSessionOptions()
    if target is not None:
      from tensorflow.python.util import compat
      _TF_SetTarget(opts, compat.as_bytes(target))
    if config is not None:
      from tensorflow.core.protobuf import config_pb2
      if not isinstance(config, config_pb2.ConfigProto):
        raise TypeError("Expected config_pb2.ConfigProto, "
                        "but got %s" % type(config))
      status = TF_NewStatus()
      config_str = config.SerializeToString()
      _TF_SetConfig(opts, config_str, status)
      if TF_GetCode(status) != 0:
        raise ValueError(TF_Message(status))
    return opts
%}

重要そうな_TF_NewSessionOptionsTF_NewSessionを見ていきます。
https://github.com/tensorflow/tensorflow/blob/r0.8/tensorflow/core/client/tensor_c_api.cc

tensor_c_api.cc
struct TF_SessionOptions {
  SessionOptions options;
};
TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; }

struct TF_Session {
  Session* session;
};

TF_Session* TF_NewSession(const TF_SessionOptions* opt, TF_Status* status) {
  Session* session;
  status->status = NewSession(opt->options, &session);
  if (status->status.ok()) {
    return new TF_Session({session});
  } else {
    DCHECK_EQ(nullptr, session);
    return NULL;
  }
}

Session自体はここで定義されています。
https://github.com/tensorflow/tensorflow/blob/r0.8/tensorflow/core/common_runtime/session.cc

session.cc
Session* NewSession(const SessionOptions& options) {
  SessionFactory* factory;
  Status s = SessionFactory::GetFactory(options, &factory);
  if (!s.ok()) {
    LOG(ERROR) << s;
    return nullptr;
  }
  return factory->NewSession(options);
}

Status NewSession(const SessionOptions& options, Session** out_session) {
  SessionFactory* factory;
  Status s = SessionFactory::GetFactory(options, &factory);
  if (!s.ok()) {
    *out_session = nullptr;
    LOG(ERROR) << s;
    return s;
  }
  *out_session = factory->NewSession(options);
  if (!*out_session) {
    return errors::Internal("Failed to create session.");
  }
  return Status::OK();
}

https://github.com/tensorflow/tensorflow/blob/r0.8/tensorflow/core/public/session.h#L81
ざっくり言ってここのSessionインスタンスをを作っていそう。

pythonの方に戻り、self.graph.as_default(), self.as_default()を見てみます。

ops.py
class _DefaultStack(threading.local):

  @contextlib.contextmanager
  def get_controller(self, default):
    try:
      self.stack.append(default)
      yield default
    finally:
      assert self.stack[-1] is default
      self.stack.pop()

class Graph(object):

  def as_default(self):
    return _default_graph_stack.get_controller(self)

graphの方は自分自身を呼び出しているだけのようです。

session.py
class BaseSession(SessionInterface):

  def as_default(self):
    return ops.default_session(self)
ops.py
_default_session_stack = _DefaultStack()

def default_session(session):
  return _default_session_stack.get_controller(weakref.ref(session))

弱参照で自分自身を呼び出しています。

まとめ

Sessionはgraphとsessionを設定して、context_managerにgraphと自分自身を設定していることがわかりました。

Session().runに続く・・・(完)

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2