LoginSignup
41
21

More than 5 years have passed since last update.

マルチスレッドに関するTensorFlowの内部構造について

Last updated at Posted at 2016-12-11

TensorFlowのMNIST Tutorialを初めて動かした時にPCファンがごーっと回りだし、おやっとCPU利用率を確認してみたらほぼほぼ100%に張り付いていたのを見たのが、本題について調べようと思ったきっかけです。ユーザー側からのコンフィグなしにどのようにマルチスレッドをTensorFlowで実現しているのか調べてみました。

環境

以下の通り仮想マシン上でTensorFlowを動かしています。
- CPU: Intel Core i5-4570, 3.2GHz, 4 cores, Haswell (ark.intel.com)
- Ubuntu 16.04 on VMware Workstation 12
- 4仮想CPU、8GBメモリ
- Python 3.5.2
- GNU G++ 5.4.0
- master branch of tensorflow github

$ git log -1
commit 48fb73a1c94ee2409382225428063d3496dc651e
Merge: d5d654c f7ff20e
Author: gunan <gunan@google.com>
Date:   Fri Dec 2 16:30:10 2016 -0800

    Merge pull request #6047 from meteorcloudy/fix_cuda_configure

    Fix cuda version detect on Windows

評価

マルチスレッド調査の目的では機械学習のプログラムは複雑すぎるので次のような単純な行列積のプログラムで評価を行いました。$N \times N$ の行列同士の積です。tf.random_normalは標準正規分布から確率変数を生成してくれます。intra_op_parallelism_threadsが本調査の肝となる部分で、陽にスレッド数を設定することが可能です。

import tensorflow as tf

N = 10000

W1 = tf.random_normal((N, N))
W2 = tf.random_normal((N, N))
C = tf.matmul(W1, W2)

myconfig = tf.ConfigProto(
    intra_op_parallelism_threads=4 )

with tf.Session(config=myconfig) as sess:
    sess.run(tf.global_variables_initializer())
    sess.run([C])
    print("Done")

実際にスレッド数の効果を見てみましょう。5回試行した結果の経過時間の平均値となっています。

intra_op_parallelism_threads N=5000 (sec.) N=10000 (sec.)
2 7.1 46.4
4 4.2 25.1
Speed up x1.68 x1.84

ちなみにスレッド数4の時のvmstat 1の抜粋が次の通りです。

procs -----------memory---------- ---swap-- -----io---- -system-- ------cpu-----
 r  b   swpd   free   buff  cache   si   so    bi    bo   in   cs us sy id wa st
 4  0 195540 4388864 126060 1205452    0    0     0     0  999  113 100  0  0  0  0
 4  0 195540 4388864 126060 1205452    0    0     0     0 1020   92 100  0  0  0  0
 4  0 195540 4388864 126060 1205452    0    0     0     0 1034   96 100  0  0  0  0
 4  0 195540 4388864 126060 1205452    0    0     0     0  987   90 100  0  0  0  0

美しいですね。

内部構造

ではいよいよソースコードを見ていきます。先程のintra_op_parallelism_threadsは次の部分で初期化されます。

tensorflow/core/common_runtime/local_device.cc
struct LocalDevice::EigenThreadPoolInfo {
  EigenThreadPoolInfo(const SessionOptions& options) {
    int32 intra_op_parallelism_threads =
        options.config.intra_op_parallelism_threads();
    if (intra_op_parallelism_threads == 0) {
      intra_op_parallelism_threads = port::NumSchedulableCPUs();
    }

陽に設定されなければOSが用いているCPUコア数が自動的に設定されるようになっています。このため何もコンフィグしなくてもCPU利用率が上がるのですね。では続きを見てみましょう。

tensorflow/core/common_runtime/local_device.cc
    eigen_worker_threads_.workers = new thread::ThreadPool(
        options.env, "Eigen", intra_op_parallelism_threads);
    eigen_threadpool_wrapper_.reset(
        new EigenThreadPoolWrapper(eigen_worker_threads_.workers));
    eigen_device_.reset(new Eigen::ThreadPoolDevice(
        eigen_threadpool_wrapper_.get(), eigen_worker_threads_.num_threads));

  ~EigenThreadPoolInfo() {
    eigen_threadpool_wrapper_.reset();
    eigen_device_.reset();
    delete eigen_worker_threads_.workers;
  }

  DeviceBase::CpuWorkerThreads eigen_worker_threads_;
  std::unique_ptr<Eigen::ThreadPoolInterface> eigen_threadpool_wrapper_;
  std::unique_ptr<Eigen::ThreadPoolDevice> eigen_device_;

intra_op_parallelism_threads個で構成されるEigen::ThreadPoolDeviceというのが作成されました。EigenというのはC++ templateからなる線形代数ライブラリです。TensorFlowは自前で線形代数ライブラリを用意するのではなくEigenのtemplateを用いています。Eigenのスレッドプールを用いて行列積が並列化される、ということになります。実質これで今回の話は終わりなのですが、、行列積についてEigenの関数が呼ばれるところまで追いかけてみます。

tensorflow/core/kernels/matmul_op.cc
template <typename Device, typename T, bool USE_CUBLAS>
class MatMulOp : public OpKernel {
 public:
  explicit MatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);
<snip>
    LaunchMatMul<Device, T, USE_CUBLAS>::launch(ctx, this, a, b, dim_pair, out);
  }

tensorflow/core/kernelsディレクトリ内でmatmulなどのオペレーションがたくさん実装されています。どれも非常に興味深いです。行列積(matmul)についてはclass MatMulOpで実装されていて、LaunchMatMul<...>::launch()を呼び出しています。

tensorflow/core/kernels/matmul_op.cc
// On CPUs, we ignore USE_CUBLAS
template <typename T>
struct LaunchMatMulCPU {
  static void launch(
      OpKernelContext* ctx, OpKernel* kernel, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      Tensor* out) {
    // An explicit vector-matrix multiply is much better optimized than an
    // implicit one and this is a bottleneck during non-batched inference.
    bool was_vector = ExplicitVectorMatrixOptimization<T>(a, b, dim_pair, out);
    if (!was_vector) {
      functor::MatMulFunctor<CPUDevice, T>()(ctx->eigen_device<CPUDevice>(),
                                             out->matrix<T>(), a.matrix<T>(),
                                             b.matrix<T>(), dim_pair);
    }
  }
};

template <typename T, bool USE_CUBLAS>
struct LaunchMatMul<CPUDevice, T, USE_CUBLAS> : public LaunchMatMulCPU<T> {};

functor::MatMulFunctor()を呼び出します。

tensorflow/core/kernels/matmul_op.cc
namespace functor {

// Partial specialization MatMulFunctor<Device=CPUDevice, T>.
template <typename T>
struct MatMulFunctor<CPUDevice, T> {
  void operator()(
      const CPUDevice& d, typename MatMulTypes<T>::out_type out,
      typename MatMulTypes<T>::in_type in0,
      typename MatMulTypes<T>::in_type in1,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair) {
    MatMul<CPUDevice>(d, out, in0, in1, dim_pair);
  }
};

}  // end namespace functor

MatMul()は次のヘッダ内で定義されています。

tensorflow/core/kernels/matmul_op.h
template <typename Device, typename In0, typename In1, typename Out,
          typename DimPair>
void MatMul(const Device& d, Out out, In0 in0, In1 in1,
            const DimPair& dim_pair) {
  out.device(d) = in0.contract(in1, dim_pair);
}

最終的にcontract()という関数が呼び出され、この先はEigenの世界です。TensorContractionThreadPool.h が関連しています。ここが並列化のガチの実装部分なのですが、時間の都合上Eigenの内部についてはまた今度です。ただ、TensorFlowの並列化はEigenに強く依存しているのだな、ということを理解して頂ければ良いと思います。

実際に過去にはMultiple CPU usage ineffectiveというIssueが報告されました。ab02c5で修正されたというので見たところ、次のようにTensorFlowが利用するEigenのrevisionを変更するというpatchで、なるほどね、ってなります。

commit ab02c5ab2f1f10bc9e51f02f5125abed449cae87
Author: Benoit Steiner <benoit.steiner.goog@gmail.com>
Date:   Mon May 16 13:55:44 2016 -0800

    Switched to the latest version of Eigen that performs much better on machines
    with many cpu cores

    For example, the wall time for the following tutorial went down from 13m35 to 5m27:
    bazel run -c opt --copt=-mavx tensorflow/examples/tutorials/word2vec/word2vec_basic
    Change: 122462177

diff --git a/eigen.BUILD b/eigen.BUILD
index 49fe45b..a657493 100644
--- a/eigen.BUILD
+++ b/eigen.BUILD
@@ -1,6 +1,6 @@
 package(default_visibility = ["//visibility:public"])

-archive_dir = "eigen-eigen-aaa010b0dd40"
+archive_dir = "eigen-eigen-a5e9085a94e8"

 cc_library(
     name = "eigen",
diff --git a/tensorflow/contrib/cmake/external/eigen.cmake b/tensorflow/contrib/cmake/external/eigen.cmake
index 4dcc491..42fa768 100644
--- a/tensorflow/contrib/cmake/external/eigen.cmake
+++ b/tensorflow/contrib/cmake/external/eigen.cmake
@@ -7,7 +7,7 @@

 include (ExternalProject)

-set(eigen_archive_hash "aaa010b0dd40")
+set(eigen_archive_hash "a5e9085a94e8")

 set(eigen_INCLUDE_DIRS
     ${CMAKE_CURRENT_BINARY_DIR}

ベクトル化

高速化にとって重要なのは並列化だけでなくCPU命令の効率的な利用です。今回使っているCPUはHaswellですのでAVXやAVX2というベクトル命令があり256-bit長のYMMレジスタを利用可能です。あるなら使ってみましょう、ということでavx2オプションをつけてbuildしてみました。TensorFlowのbuild手順についてはこちらも参考にしてみてください。

$ bazel build -c opt --copt=-mavx2 //tensorflow/tools/pip_package:build_pip_package

intra_op_parallelism_threadsオプションは設定していません。つまりコア数の分だけ並列化されます。今回の環境では4となります。経過時間を比較したのが次の表です。EigenがAVX命令に対応しているためそれなりに効いていることが分かります。

N No AVX (sec.) AVX/AVX2 (sec.) Speed up
5000 4.2 2.9 x1.4
10000 25.1 15.6 x1.6

まとめ

TensorFlowがどのようにマルチスレッド化されているのか行列積(tf.matmul)について調査しました。EigenというC++ templateからなる数値演算ライブラリが、TensorFlowで初期化されたThreadPoolのリソースを用いて実際の並列化を行っていることを説明しました。また、ベクトル化の効果を示しました。

今後

Eigenについて調査を継続したいです。TensorFlowについては、行列積以外のオペレーションについての並列化状況の調査を進めたいです。また、CPUプラットフォームに応じた適切なコンパイラオプションの提供、ログレベルの適切な実装、といったところがTensorFlowの改善点として調査の中で見えてきたので試行錯誤してみたいです。

41
21
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
41
21