Library探索
機械学習をすぐに習得するのは難しいので、TensorFlowをライブラリとして見ていきたいと思います。
全般
基本的にtensorflow/python
配下のclass,methodは全てimportされています。
一部、C++のコードがBazelというツールでPythonライブラリとしてビルドされているようです。
Session
計算を実行するクラス。
https://github.com/tensorflow/tensorflow/blob/r0.8/tensorflow/python/client/session.py
class Session(BaseSession):
def __init__(self, target='', graph=None, config=None):
super(Session, self).__init__(target, graph, config=config)
self._context_managers = [self.graph.as_default(), self.as_default()]
baseクラスの`init``を見ます。
class BaseSession(SessionInterface):
def __init__(self, target='', graph=None, config=None):
if graph is None:
self._graph = ops.get_default_graph()
else:
self._graph = graph
...,
opts = tf_session.TF_NewSessionOptions(target=target, config=config)
try:
status = tf_session.TF_NewStatus()
try:
self._session = tf_session.TF_NewSession(opts, status)
if tf_session.TF_GetCode(status) != 0:
raise RuntimeError(compat.as_text(tf_session.TF_Message(status)))
finally:
tf_session.TF_DeleteStatus(status)
finally:
tf_session.TF_DeleteSessionOptions(opts)
graphとsessionを設定していることがわかる。
from tensorflow.python.framework import ops
from tensorflow.python import pywrap_tensorflow as tf_session
opsとtf_sessionを定義しているファイルを見てみる。
class _DefaultStack(threading.local):
def __init__(self):
super(_DefaultStack, self).__init__()
self.stack = []
def get_default(self):
return self.stack[-1] if len(self.stack) >= 1 else None
class _DefaultGraphStack(_DefaultStack):
def __init__(self):
super(_DefaultGraphStack, self).__init__()
self._global_default_graph = None
def get_default(self):
ret = super(_DefaultGraphStack, self).get_default()
if ret is None:
ret = self._GetGlobalDefaultGraph()
return ret
def _GetGlobalDefaultGraph(self):
if self._global_default_graph is None:
self._global_default_graph = Graph()
return self._global_default_graph
_default_graph_stack = _DefaultGraphStack()
def get_default_graph():
return _default_graph_stack.get_default()
graphにはGraph()
が入りそうです。
pywrap_tensorflow
は少しややこしくて、Bazelというビルドツールでビルドされるファイルのようです。
tf_py_wrap_cc(
name = "pywrap_tensorflow",
srcs = ["tensorflow.i"],
...,
)
%include "tensorflow/python/client/tf_session.i"
tf_session.i
これはどうやらswigというC/C++の機能を他言語に組み込む定義ファイルのようです。
http://www.swig.org/Doc3.0/SWIGDocumentation.html#SWIG_nn2
// Include the functions from tensor_c_api.h, except TF_Run.
%unignore TF_NewStatus;
%unignore TF_DeleteStatus;
%unignore TF_GetCode;
%unignore TF_Message;
%rename("_TF_SetTarget") TF_SetTarget;
%rename("_TF_NewSessionOptions") TF_NewSessionOptions;
%unignore TF_DeleteSessionOptions;
%include "tensorflow/core/public/tensor_c_api.h"
%insert("python") %{
def TF_NewSessionOptions(target=None, config=None):
opts = _TF_NewSessionOptions()
if target is not None:
from tensorflow.python.util import compat
_TF_SetTarget(opts, compat.as_bytes(target))
if config is not None:
from tensorflow.core.protobuf import config_pb2
if not isinstance(config, config_pb2.ConfigProto):
raise TypeError("Expected config_pb2.ConfigProto, "
"but got %s" % type(config))
status = TF_NewStatus()
config_str = config.SerializeToString()
_TF_SetConfig(opts, config_str, status)
if TF_GetCode(status) != 0:
raise ValueError(TF_Message(status))
return opts
%}
重要そうな_TF_NewSessionOptions
とTF_NewSession
を見ていきます。
https://github.com/tensorflow/tensorflow/blob/r0.8/tensorflow/core/client/tensor_c_api.cc
struct TF_SessionOptions {
SessionOptions options;
};
TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; }
struct TF_Session {
Session* session;
};
TF_Session* TF_NewSession(const TF_SessionOptions* opt, TF_Status* status) {
Session* session;
status->status = NewSession(opt->options, &session);
if (status->status.ok()) {
return new TF_Session({session});
} else {
DCHECK_EQ(nullptr, session);
return NULL;
}
}
Session自体はここで定義されています。
https://github.com/tensorflow/tensorflow/blob/r0.8/tensorflow/core/common_runtime/session.cc
Session* NewSession(const SessionOptions& options) {
SessionFactory* factory;
Status s = SessionFactory::GetFactory(options, &factory);
if (!s.ok()) {
LOG(ERROR) << s;
return nullptr;
}
return factory->NewSession(options);
}
Status NewSession(const SessionOptions& options, Session** out_session) {
SessionFactory* factory;
Status s = SessionFactory::GetFactory(options, &factory);
if (!s.ok()) {
*out_session = nullptr;
LOG(ERROR) << s;
return s;
}
*out_session = factory->NewSession(options);
if (!*out_session) {
return errors::Internal("Failed to create session.");
}
return Status::OK();
}
https://github.com/tensorflow/tensorflow/blob/r0.8/tensorflow/core/public/session.h#L81
ざっくり言ってここのSessionインスタンスをを作っていそう。
pythonの方に戻り、self.graph.as_default(), self.as_default()
を見てみます。
class _DefaultStack(threading.local):
@contextlib.contextmanager
def get_controller(self, default):
try:
self.stack.append(default)
yield default
finally:
assert self.stack[-1] is default
self.stack.pop()
class Graph(object):
def as_default(self):
return _default_graph_stack.get_controller(self)
graphの方は自分自身を呼び出しているだけのようです。
class BaseSession(SessionInterface):
def as_default(self):
return ops.default_session(self)
_default_session_stack = _DefaultStack()
def default_session(session):
return _default_session_stack.get_controller(weakref.ref(session))
弱参照で自分自身を呼び出しています。
まとめ
Sessionはgraphとsessionを設定して、context_managerにgraphと自分自身を設定していることがわかりました。
Session().runに続く・・・(完)