C++
CUDA
kernel
TensorFlow

【TensorFlow】カーネル作成プロセスからGPU(CUDA)のアタッチまでの大まかな流れを解析し, Multi CPU, Multi GPUの現状を調査 - Part1

More than 1 year has passed since last update.

1. TensorFlowのカーネル起動の仕組みについて

TensorFlowでは、以下の仕組みとなっていました。
- 静的なグローバル変数を定義することでプロセス起動と同時に自動でスレッドを作成し、自動でカーネル環境を作成する
⇒ 将来は、オプション設定できるようにAPIとして穴あけされるだろう思われる。
- GPUの言語は、CUDAを使用しておりNVIDIAのGPUを使用している事を前提としている (Dual対応)
- C++ 11のコード規約に基づいた記述となっている。(auto型など)
- コマンドを見る限りでは、Linux Kernel環境を想定した作りとなっている

CUDAについて補足:

NVIDIA製の並列計算アーキテクチャの総称である。
詳細は、以下の通り
URL: CUDA

・GPUの設定方法
公式で公開されているGPUの設定方法

※本記事は、出だしから間違えているので、修正します。

2. 大まかなシーケンス

現状のソースで、大まかなシーケンスをソースで追っていきます。
C++のクラス設計に結構活かせるノウハウも詰まっているので、順番に見ていきましょう。
コードリーディングが、貴方のコーディング力アップに繋がることは間違いないです。
※クラスポインタなどの概念も普通に出てきますが、変数のポインタと同じ考え方で捉えると分かりやすいでしょう。
実際に、関数名は変数名の用に扱えます。
簡単なケースでは、以下の通りになります。

sample.cpp
  // 何の変哲もないfunc1です。
  void func1(void* pData) {
    func1("func1");
  }
  // 何の変哲もないfunc2です。
  void func2(void* pData) {
    func1("func2");
  }

  // 共通の形式の関数をポインタを使って纏めます
  void (*pFuncList[])(void* pData) = {
    func1,  // func1を指定することで、func1の関数ポインタが渡されます。
    func2   // func2を指定することで、func2の関数ポインタが渡されます。
  }

  void main(void) {
    char* pBuffer = "Test";
    for(int i = 0; i < sizeof(pFuncList) / sizeof(pFuncList[0])) {
      pFuncList[++i](pBuffer);
    }
  }

それでは、本番のコードに入ります。
まずは、セッションの確立をしている部分から見ていきましょう。
※コード中の日本語のコメントが、今回のシーケンスです。
実際は、多くのコードでコメントはありません。

2.0 スタート地点
Source URL: direct_session.cc
Line. 554 - 572

direct_session.cc
class DirectSessionFactory : public SessionFactory {
 public:
  DirectSessionFactory() {} // 現状は、こちらに通って何もしていない

  // 将来は、オプション設定して下の関数を通したいらしい…
  // 以下を作動させれば、Dual CPU及びDual GPUを利用したAI 実行が可能になります。
  // 穴あけが大変そうですね  :disappointed_relieved: 
  Session* NewSession(const SessionOptions& options) override {
    std::vector<Device*> devices;
    DeviceFactory::AddDevices(options, "/job:localhost/replica:0/task:0",
                              &devices);
    return new DirectSession(options, new DeviceMgr(devices));
  }
};

class DirectSessionRegistrar {
 public:
  DirectSessionRegistrar() {
    SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
  }
};
// グローバルの静的変数として定義されている。
// これが、現状のセッションの仕組みのスタートです。
// デフォルトコンストラクタでは、何もしないので起動していない。
// 現状は、別のセッションが動いているのでしょうか…
static DirectSessionRegistrar registrar;

2.1 スレッドの作成と起動
2.0で機能していないことが分かっていますが、とりあえずoptionありで機能すると仮定してのシーケンスを見ていきましょう。

まずは、NewThreadPoolというローカル関数で、スレッドを作成します。
Line: 52 - 70

direct_session.cc
// 並列処理スレッドを作成します。
thread::ThreadPool* NewThreadPool(const SessionOptions& options) {
  int32 inter_op_parallelism_threads =
      options.config.inter_op_parallelism_threads();
  if (inter_op_parallelism_threads == 0) {
    // Default to using the number of cores available in the process.
    inter_op_parallelism_threads = port::NumSchedulableCPUs();
  }
  VLOG(1) << "Direct session inter op parallelism threads: "
          << inter_op_parallelism_threads;
  return new thread::ThreadPool(options.env, "Compute",
                                inter_op_parallelism_threads);
}

thread::ThreadPool* GlobalThreadPool(const SessionOptions& options) {
  static thread::ThreadPool* const thread_pool = NewThreadPool(options);
  return thread_pool;
}

2.2 スレッドからの割込み
以下の部分でスレッドに登録したセッションを実行します。
Line: 208 - 427

direct_session.cc
Status DirectSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
                          const std::vector<string>& output_names,
                          const std::vector<string>& target_nodes,
                          std::vector<Tensor>* outputs) {
  {
    mutex_lock l(graph_def_lock_);
    if (!graph_created_) {
      return errors::InvalidArgument(
          "Session was not created with a graph before Run()!");
    }
  }

  // Extract the inputs names for this run of the session.
  std::vector<string> input_tensor_names;
  input_tensor_names.reserve(inputs.size());
  for (const auto& it : inputs) {
    input_tensor_names.push_back(it.first);
  }

  // Check if we already have an executor for these arguments.
  // ここで、すでに起動しているデバイス及び新規に起動するデバイスリストを取得する
  ExecutorsAndKeys* executors_and_keys;
  Status s = GetOrCreateExecutors(input_tensor_names, output_names,
                                  target_nodes, &executors_and_keys);
  if (!s.ok()) {
    return s;
  }

  IntraProcessRendezvous* rendez =
      new IntraProcessRendezvous(device_mgr_.get());
  core::ScopedUnref rendez_unref(rendez);

  // Insert the input tensors into the local rendezvous by their
  // rendezvous key.
  for (const auto& input : inputs) {
    const string& input_key = executors_and_keys->input_keys[input.first];
    s = rendez->Send(input_key, Rendezvous::Args(), input.second, false);
    if (!s.ok()) {
      rendez->StartAbort(s);
      return s;
    }
  }

  // Start parallel Executors.
  Notification executors_done;
  const int num_executors = executors_and_keys->items.size();
  ExecutorBarrier* barrier = new ExecutorBarrier(
      num_executors, rendez, [&executors_done, &s](const Status& ret) {
        s = ret;
        executors_done.Notify();
      });

  Executor::Args args;
  args.rendezvous = rendez;
  args.cancellation_manager = cancellation_manager_;
  args.runner = [this](Executor::Args::Closure c) { SchedClosure(c); };

  // ここで実行処理をしている
  // item.executorの内容は、後で出てくるので後に整理します。
  // item.executorは、ExecutorImplクラスのインスタンスポインタである。
  for (const auto& item : executors_and_keys->items) {
    item.executor->RunAsync(args, barrier->Get());
  }

  // スレッドが終わるまで同期する。
  executors_done.WaitForNotification();

  TF_RETURN_IF_ERROR(s);

  if (!output_names.empty()) {
    outputs->resize(output_names.size());
  }

  // Get the outputs from the rendezvous
  for (size_t output_offset = 0; output_offset < output_names.size();
       ++output_offset) {
    const string& output_key =
        executors_and_keys->output_keys[output_names[output_offset]];
    Tensor output_tensor;
    bool is_dead;

    // Fetch data from the Rendezvous.
    s = rendez->Recv(output_key, Rendezvous::Args(), &output_tensor, &is_dead);
    if (is_dead) {
      s = errors::InvalidArgument("The tensor returned for ",
                                  output_names[output_offset],
                                  " was not valid.");
    }
    if (!s.ok()) {
      rendez->StartAbort(s);
      outputs->clear();
      return s;
    }

    (*outputs)[output_offset] = output_tensor;
  }

  return s;
}

Status DirectSession::GetOrCreateExecutors(
    gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
    gtl::ArraySlice<string> target_nodes,
    ExecutorsAndKeys** executors_and_keys) {
  // Sort the inputs and outputs, so we don't create separate
  // executors when a user passes in the same inputs/outputs in
  // different orders.
  //
  // We could consider some other signature instead of sorting that
  // preserves the same property to avoid the sort in the future.
  std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
  std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
  std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
  std::sort(inputs_sorted.begin(), inputs_sorted.end());
  std::sort(outputs_sorted.begin(), outputs_sorted.end());
  std::sort(tn_sorted.begin(), tn_sorted.end());

  const string key = strings::StrCat(str_util::Join(inputs_sorted, ","), "->",
                                     str_util::Join(outputs_sorted, ","), "/",
                                     str_util::Join(tn_sorted, ","));

  // See if we already have the executors for this run.
  {
    mutex_lock l(executor_lock_);  // could use reader lock
    auto it = executors_.find(key);
    if (it != executors_.end()) {
      *executors_and_keys = it->second;
      return Status::OK();
    }
  }

  // The executor_lock_ is intentionally released while executor is
  // being created.
  FunctionLibraryDefinition* fdefs;
  std::unordered_map<string, Graph*> graphs;
  Status s = CreateGraphs(inputs, outputs, target_nodes, &fdefs, &graphs);
  if (!s.ok()) {
    return s;
  }

  bool has_control_flow = false;
  for (const auto& graph : graphs) {
    for (const Node* n : graph.second->nodes()) {
      if (IsControlFlow(n)) {
        has_control_flow = true;
        break;
      }
    }
    if (has_control_flow) break;
  }

  std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
  ek->func_defs = fdefs;
  ek->items.reserve(graphs.size());
  auto runner = [this](Executor::Args::Closure c) { SchedClosure(c); };
  for (const auto& graph : graphs) {
    const string& partition_name = graph.first;
    Graph* partition_graph = graph.second;
    const int graph_def_version = partition_graph->version();

    Device* device;
    s = device_mgr_->LookupDevice(partition_name, &device);
    if (!s.ok()) {
      return s;
    }

    ek->items.resize(ek->items.size() + 1);
    auto* item = &(ek->items.back());
    // 関数ポインタを登録します。
    // カーネル作成用の関数などのポインタを設置します。
    item->flib =
        NewFunctionLibraryRuntime(device, runner, graph_def_version, fdefs);

    LocalExecutorParams params;
    params.has_control_flow = has_control_flow;
    params.device = device;
    params.function_library = item->flib;
    auto lib = item->flib;
    auto opseg = device->op_segment();

    // カーネル作成関数です。
    // create_kernelを関数の様に書くことで関数としてコールできます。
    // OpKernel* pKernel;
    // params.create_kernel(pKernel);
    params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
                                              OpKernel** kernel) {
      auto create_fn = [lib, &ndef](OpKernel** kernel) {
        return lib->CreateKernel(ndef, kernel);
      };
      // Kernels created for subgraph nodes need to be cached.  On
      // cache miss, create_fn() is invoked to create a kernel based
      // on the function library here + global op registry.
      return opseg->FindOrCreate(session_handle_, ndef.name(), kernel,
                                 create_fn);
    };
    // カーネル削除関数です。 (何もやっていないようですね…)
    // create_kernelを関数の様に書くことで関数としてコールできます。
    params.delete_kernel = [](OpKernel* kernel) {
      // Do nothing because 'kernel' is owned by opseg above.
    };

    // 実行するためのクラスをitem->executorに作成する
    s = NewLocalExecutor(params, partition_graph, &item->executor);
    if (!s.ok()) {
      return s;
    }
  }

  // 今回は、重要ではないので省略

  return Status::OK();
}

2.3 カーネルの作成・初期化
カーネルの初期化をこの中でやっているため、ここからNewLocalExecutor内部に移動します。
new演算子を使用して、インスタンスを生成した後のInitilizeの中でカーネルを作成します。
Source URL: executor.cc
Line: 2130 - 2140

executor.cc
Status NewLocalExecutor(const LocalExecutorParams& params, const Graph* graph,
                        Executor** executor) {
  ExecutorImpl* impl = new ExecutorImpl(params, graph);
  // この中でカーネルクラスを作成します。
  Status s = impl->Initialize();
  if (s.ok()) {
    *executor = impl;
  } else {
    delete impl;
  }
  return s;
}

ExecutorImpl::Initialize()の中を見ていきましょう。
Line: 244 - 280

executor.cc
Status ExecutorImpl::Initialize() {
  const int num_nodes = graph_->num_node_ids();
  nodes_.resize(num_nodes);

  Status s;
  total_tensors_ = 0;

  // Preprocess every node in the graph to create an instance of op
  // kernel for each node;
  for (const Node* n : graph_->nodes()) {
    const int id = n->id();
    NodeItem* item = &nodes_[id];
    item->node = n;
    item->input_start = total_tensors_;
    total_tensors_ += n->num_inputs();
    // ここでカーネルを作成します。
    s = params_.create_kernel(n->def(), &item->kernel);
    if (!s.ok()) {
      // なにやらノードのアタッチをしているようですが、今回は関係ないため未確認です。
      s = AttachDef(s, n->def());
      LOG(ERROR) << "Executor failed to create kernel. " << s;
      break;
    }
    CHECK(item->kernel);

    // Initialize static information about the frames in the graph.
    if (IsEnter(n)) {
      string frame_name;
      s = GetNodeAttr(n->def(), "frame_name", &frame_name);
      if (!s.ok()) return s;
      ++frame_input_count_[frame_name];
    }
  }
  if (params_.has_control_flow) {
    VLOG(2) << "Graph has control flow.";
  }
  if (!s.ok()) return s;
  return SetAllocAttrs();
}

2.4 カーネルの新規作成
params_.create_kernel(n->def(), &item->kernel);
が、どこに繋がるかと言うと、
Source URL: direct_session.cc
Line: 380 -390

direct_session.cc
    params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
                                              OpKernel** kernel) {
      // この中に移動します。

      // create_fnにCreateKernel関数をコールする関数ポインタをアタッチする。
      auto create_fn = [lib, &ndef](OpKernel** kernel) {
        return lib->CreateKernel(ndef, kernel);
      };
      // Kernels created for subgraph nodes need to be cached.  On
      // cache miss, create_fn() is invoked to create a kernel based
      // on the function library here + global op registry.

      // デバイスから取得したセグメントから起動済みのデバイスを検索または
      // CreateKernel関数をコールするかを実行する
      return opseg->FindOrCreate(session_handle_, ndef.name(), kernel,
                                 create_fn);
    };

2.5 カーネルの検索及び新規作成
OpSegment::FindOrCreateに移動します。
すでに起動している場合は、作成してしているカーネルのポインタを返します。
初回起動等で、まだ作成していないケースはCreateKernel関数をコールすることになります。
Source URL: op_segment.cc
Line: 36 - 70

op_segment.cc
Status OpSegment::FindOrCreate(const string& session_handle,
                               const string& node_name, OpKernel** kernel,
                               CreateKernelFn create_fn) {
  // スレッドのデッドロック対策のためのミューテックス - ON
  {
    mutex_lock l(mu_);
    auto item = gtl::FindPtrOrNull(sessions_, session_handle);
    if (item == nullptr) {
      return errors::NotFound("Session ", session_handle, " is not found.");
    }
    // SymbolicGradientHelper経由で作成されたNodeではなく、
    // カーネル名とノード名で検索して、マッチした場合はカーネルのポインタを使用する。
    *kernel = gtl::FindPtrOrNull(item->name_kernel, node_name);
    if (*kernel != nullptr) {
      return Status::OK();
    }
  }

  // 検索して見つからなかったので、CreateKernel関数をコールする
  // この中身は、「lib->CreateKernel(ndef, kernel)」になります。
  // libは、FunctionLibraryRuntimeImplなので、そちらに移動します。
  // 理由は、NewFunctionLibraryRuntimeを参照して下さい。
  Status s = create_fn(kernel);
  if (!s.ok()) {
    LOG(ERROR) << "Create kernel failed: " << s;
    return s;
  }

  // スレッドのデッドロック対策のためのミューテックス - OFF
  {
    mutex_lock l(mu_);
    auto item = gtl::FindPtrOrNull(sessions_, session_handle);
    if (item == nullptr) {
      return errors::NotFound("Session ", session_handle, " is not found.");
    }
    OpKernel** p_kernel = &(item->name_kernel[node_name]);
    if (*p_kernel == nullptr) {
      *p_kernel = *kernel;  // Inserts 'kernel' in the map.
    } else {
      delete *kernel;
      *kernel = *p_kernel;
    }
  }
  return Status::OK();
}

補足

gtl::FindPtrOrNullは、以下のヘッダーで定義されています。
Source URL: map_util.h

NewFunctionLibraryRuntime関数は、以下のソースファイルで定義されています。
Source URL: function.cc

FunctionLibraryRuntimeImpl::CreateKernelを見ていきましょう。
Source URL: Source URL: function.cc
Line: 374 - 401

function.cc
Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
                                                OpKernel** kernel) {
  // lib_def_は、DirectSession::CreateGraphs関数で作成したライブラリの関数ポインタ群と思われるが、
  // 詳細は、複雑になりそうなので別の機会とする。
  // ndefは、Node関連だろうがコチラも詳細は別の機会に解析する。
  if (ndef.op() != kGradientOp && (lib_def_->Find(ndef.op()) == nullptr)) {
    return CreateNonCachedKernel(device_, this, ndef, graph_def_version_,
                                 kernel);
  }

  // Try to instantiate this function for the func/attr. Maybe its
  // cached already.
  // 関数と属性からの関数のインスタンス化を試みる
  // 詳細な仕組みは、次の機会に実施する
  Handle handle;
  TF_RETURN_IF_ERROR(Instantiate(ndef.op(), ndef.attr(), &handle));

  const FunctionBody* fbody = GetFunctionBody(handle);
  CHECK_NOTNULL(fbody);

  // Constructs a CallOp kernel for running the instantiated function.
  // CallOp(Call Operationのこと?)カーネルをインスタンス化し起動する。
  // この辺りも、次の機会とする。
  Status s;
  auto device_type = DeviceType(device_->attributes().device_type());
  OpKernelConstruction construction(
      device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef,
      &fbody->fdef.signature(), this, fbody->arg_types, fbody->ret_types,
      graph_def_version_, &s);
  *kernel = new CallOp(handle, &construction);
  if (!s.ok()) {
    delete kernel;
  }
  return s;
}

2.6 Executorを実行
2.2まで戻って、2.3で作成したExecutorを非同期で実行します。
その流れを見ていきましょう。
Source URL: direct_session.cc
Line: 208 - 427

direct_session.cc
Status DirectSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
                          const std::vector<string>& output_names,
                          const std::vector<string>& target_nodes,
                          std::vector<Tensor>* outputs) {
  {
    mutex_lock l(graph_def_lock_);
    if (!graph_created_) {
      return errors::InvalidArgument(
          "Session was not created with a graph before Run()!");
    }
  }

  // Extract the inputs names for this run of the session.
  std::vector<string> input_tensor_names;
  input_tensor_names.reserve(inputs.size());
  for (const auto& it : inputs) {
    input_tensor_names.push_back(it.first);
  }

  // Check if we already have an executor for these arguments.
  ExecutorsAndKeys* executors_and_keys;
  Status s = GetOrCreateExecutors(input_tensor_names, output_names,
                                  target_nodes, &executors_and_keys);
  if (!s.ok()) {
    return s;
  }

  // -----------------------ここまで 2.2 ----------------------------

  IntraProcessRendezvous* rendez =
      new IntraProcessRendezvous(device_mgr_.get());
  core::ScopedUnref rendez_unref(rendez);

  // Insert the input tensors into the local rendezvous by their
  // rendezvous key.
  for (const auto& input : inputs) {
    const string& input_key = executors_and_keys->input_keys[input.first];
    s = rendez->Send(input_key, Rendezvous::Args(), input.second, false);
    if (!s.ok()) {
      rendez->StartAbort(s);
      return s;
    }
  }

  // Start parallel Executors.
  Notification executors_done;
  const int num_executors = executors_and_keys->items.size();
  ExecutorBarrier* barrier = new ExecutorBarrier(
      num_executors, rendez, [&executors_done, &s](const Status& ret) {
        s = ret;
        executors_done.Notify();
      });

  Executor::Args args;
  args.rendezvous = rendez;
  args.cancellation_manager = cancellation_manager_;
  args.runner = [this](Executor::Args::Closure c) { SchedClosure(c); };

  // ポイントは、ここです。
  // ここで、executorを非同期実行します。
  for (const auto& item : executors_and_keys->items) {
    item.executor->RunAsync(args, barrier->Get());
  }

  executors_done.WaitForNotification();

  // 今回は、無用なので省略

  return s;
}

では、item.executor->RunAsyncの中を追いましょう。
ExecutorImpl::RunAsyncに繋がっていますので、こちらのコードに移動しましょう。

Source URL: executor.cc
Line: 2118 - 2125

executor.cc
// NOTE(yuanbyu): Use the executor that supports control flow by default.
const bool use_control_flow_executor = true;
void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
  // ControlFlowというオプション設定でしょうか?
  // 全体見てないので、まだ不明です。 (分かり次第追記します。)
  if (params_.has_control_flow || use_control_flow_executor) {
    (new ExecutorState(args, this))->RunAsync(done);
  } else {
    // 今回は、こちらのシーケンスを追っていきましょう。
    (new SimpleExecutorState(args, this))->RunAsync(done);
  }
}

SimpleExecutorState::RunAsyncに移動します。
Line: 1794 - 1819

executor.cc
void SimpleExecutorState::RunAsync(Executor::DoneCallback done) {
  const Graph* graph = impl_->graph_;
  ReadyNodeIds ready;

  // Ask the device to fill in the device context map.
  Device* device = impl_->params_.device;
  device->FillContextMap(graph, &device_context_map_);

  // 中身まで見切れていないので推測ですが、ノードにリンクされている画像を解析して
  // 中にエッヂがあるかカウントしていると思われます。
  // エッヂが無い画像が一つでもあれば処理を続行する。 (この場合の処理とは、ExecutorBarrierのこと)
  // エッヂがすべてある場合は、ScheduleReadyに入るようになっている。
  // ここでのエッヂの定義は、画像に関連付いているノードの数と思われる。
 // 理由は、補足に記載する。
  for (const Node* n : graph->nodes()) {
    const int id = n->id();
    const int num_in_edges = n->in_edges().size();
    pending_[id].Set(num_in_edges);
    if (num_in_edges == 0) {
      ready.push_back(id);
    }
  }
  if (ready.empty()) {
    done(Status::OK());
  } else {
    num_active_ = ready.size();
    done_cb_ = done;
    input_tensors_.resize(impl_->total_tensors_);
    // Schedule to run all the ready ops in thread pool.
    ScheduleReady(ready, nullptr);
  }
}

補足

以下のソースに定義されている、Graph::AddEdge関数にて挿入処理しています。
挿入のタイミングがNodeの増減に伴っていることから, 画像とリンクしているノードの数と推測しました。
Source URL: graph.cc

ExecutorBarrierの定義は、以下のファイルです。
Source URL: executor.h

3. 結論

現在は、スレッド作成するルートは使用されておらずデバッグ中らしい…
長くなってしまったので、一旦ここで区切ります。
ぐちゃぐちゃですみません。 :fearful:

次回は、『OpKernelContextを作成』からGPUの起動・アタッチまでの流れをソースと照らし合わせつつ見ていきます。

初めて書いたので分かりづらいと思います。
意見や要望等ありましたら、遠慮なく下さい。
お待ちしています。