#はじめに
PyTorchのCUDAプログラミングに絞って並列処理を見てみる。なお、CPU側の並列処理は別資料に記載済みである。ここでは、
- C++の拡張仕様であるCUDAの基礎知識
- カーネルレベルの並列処理
- add関数の実装
- im2col関数の実装
- ストリームレベルの並列処理
- DistributedDataParallelの呼び出し処理の実装
について説明する。
おことわり
- PyTorchでは、ATen配下で演算処理を行っている。しかし、その前身であるTorchの資産を引き継いでいるため、THC(TorcH Cuda)のTensorからATen/nativeに書き換え中である。古いTHCTensor部分は参考資料として引用しておくが、言及はしない。また、この書き換え作業は1年以上継続中でありゆっくりと進んでいる。あと一年以上はかかるのではと思う。
- BLAS(行列演算)やcudnn(深層学習演算)は、ライブラリで定義された関数を呼ぶ形式である。関数内で並列処理されるため、プログラマから見て明示的な要素演算はない。そのため、ここでは言及しない。
#CUDAの基礎知識
CUDAレベルの並列処理は、2階層になっている。
- カーネルレベル
- ストリームレベル
PyTorchでは、単一デバイスの場合、カーネルレベルの処理を行う。その際、ストリームは一つ使う。しかし、複数デバイス(GPGPU)では、複数ストリームを用いて実行する。本稿では、単一デバイスの際のカーネルレベルの並列処理から始め、複数デバイスのストリームレベルの並列処理へと進む。
##カーネルレベルの並列処理
CUDAでは、GPGPUで動かすための記法が追加されている。ここでは、カーネルの起動パラメータ、カーネル内でのスレッド番号、およびGPGPUでの動作指定について簡単に紹介する。
<<<Dg, Db, Ns, S>>>
について
CUDAのソフトウェア階層は、grid, block, threadになっている。ハードウェアの階層では、GPGPUデバイス、Streaming Multiprocessor、CUDAコアにそれぞれ相当する。これを、CUDAカーネルの起動の際に<<<Dg, Db>>>
で引き渡す。Dbは、Warpを考慮し32の倍数が必須であり、1024が最大値である。
更に<<< >>>
の後ろ(3、4項目目)を指定する場合、grid, blockに加えて、メモリの共有サイズ、ストリームを指定する。
粒度 | ソフトウェア | ハードウェア |
---|---|---|
大 | grid | GPGPU |
中 | block | Stream Multiprocessor (SM) |
小 | thread | CUDA core |
なお、PyTorchの場合、grid, blockを個別に設定する場合もあるがgetApplyGrid
やgetApplyBlock
という関数を使うこともある。
###CUDA組み込み変数
上記で起動した関数は、以下の変数によりスレッド番号等を取得し処理を行う。
blockIdx
, threadIdx
, blockDim
, gridDim
, warpSize
###関数の動く場所の宣言
__global__
等により、デバイス側(GPGPU)で動かす等の設定を行う。
##ストリームレベルの並列処理
PyTorchで使うストリーム処理は大まかに、生成、同期、状態取得の3つが使われる。そして、デバイス(GPGPU)ごとにストリームが設定される。
- ストリームの生成
- ストリームの同期
- ストリームの状態取得
##その他
CUDAの環境変数等は、Programming Manualに記載されている。このため、ある程度大きな粒度の設定の場合、環境変数を設定したほうが良い。
実装例(addを見てみる)
ATen層内で、dispatchからは、以下の関数が呼び出される。
REGISTER_DISPATCH(add_stub, &add_kernel_cuda);
ATenで、addは以下のように定義している。
void add_kernel_cuda(TensorIterator& iter, Scalar alpha_scalar) {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.dtype(), "add_cuda/sub_cuda", [&]() {
auto alpha = alpha_scalar.to<scalar_t>();
gpu_kernel_with_scalars(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
return a + alpha * b;
});
});
}
ここで、AT_DISPATCH_ALL_TYPES_AND2
は、複数のテンソルデータ種類(HalfTensor等)を定義するマクロである。そして、gpu_kernel_with_scalars
以下が実際にCUDAで処理を行う。gpu_kernel_with_scalars
からgpu_kernel
、gpu_kernel_impl
そして、launch_kernel
を経由してelementwise_kernel
の順に呼び出して、処理を行う。
まず、gpu_kernel_impl
にて、テンソルの要素ごとの演算に落とし込む。Tensorのデータを保持するiter
からdata
やstrides
への変換を行う。そして、launch_kernel
へ引き渡す。ここで、grid
およびblock
の計算前の数そして、ラムダ式による演算(addなので足し算)が定義されている。
template <typename func_t>
void gpu_kernel_impl(TensorIterator& iter, const func_t& f) {
using traits = function_traits<func_t>;
using arg0_t = typename traits::result_type;
constexpr int ntensors = traits::arity + 1;
TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
TORCH_INTERNAL_ASSERT(iter.ntensors() == traits::arity + 1);
at::detail::Array<char*, ntensors> data;
for (int i = 0; i < ntensors; i++) {
data[i] = (char*)iter.data_ptr(i);
}
int64_t numel = iter.numel();
if (iter.is_trivial_1d()) {
auto inner_strides = iter.get_inner_strides();
at::detail::Array<int, ntensors> strides;
for (int i = 0; i < ntensors; i++) {
strides[i] = inner_strides[i];
}
launch_kernel<launch_size_1d, 1>(numel, [=]GPU_LAMBDA(int idx) {
arg0_t* out = (arg0_t*)(data[0] + strides[0] * idx);
*out = invoke(f, &data.data[1], &strides.data[1], idx);
});
} else {
auto offset_calc = make_offset_calculator<traits::arity + 1>(iter);
launch_kernel<launch_size_nd, launch_bound2>(numel, [=]GPU_LAMBDA(int idx) {
auto offsets = offset_calc.get(idx);
arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
*out = invoke(f, &data.data[1], &offsets.data[1], 1);
});
}
}
上記で、launch_kernel
の< >
がgrid, blockの入力値になる。そして値は、以下のように定義されている。このため、CUDAの場合は、block
に対して512もしくは128が設定される。参考までに、HIPは、AMD用の設定である。
#ifdef __HIP_PLATFORM_HCC__
static constexpr int launch_size_1d = 1024;
static constexpr int launch_size_nd = 1024;
static constexpr int launch_bound2 = 1;
#else
static constexpr int launch_size_1d = 512;
static constexpr int launch_size_nd = 128;
static constexpr int launch_bound2 = 4;
#endif
CUDAのカーネルは、launch_kernel
から呼び出す。elementwise_kernel
は、CUDAのカーネル独自の記法<<< >>>
で記載されている。ここで、nt
, vt
が上から引き渡された変数であり、block
およびgrid
として値が設定される。なお、CUDAのデータ形式dim3
なのでblock.x
として変数へのアクセスを行う。
template<int nt, int vt, typename func_t>
static void launch_kernel(int64_t N, const func_t& f) {
TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
if (N == 0) {
return;
}
dim3 block(nt);
dim3 grid((N + block.x * vt - 1) / (block.x * vt));
auto stream = at::cuda::getCurrentCUDAStream();
elementwise_kernel<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
AT_CUDA_CHECK(cudaGetLastError());
}
elementwise_kernel
では、スレッド毎に、並列処理を行う。
template<int nt, int vt, typename func_t>
C10_LAUNCH_BOUNDS_2(nt, launch_bound2)
__global__ void elementwise_kernel(int N, func_t f) {
int tid = threadIdx.x;
int nv = nt * vt;
int idx = nv * blockIdx.x + tid;
#pragma unroll
for (int i = 0; i < vt; i++) {
if (idx < N) {
f(idx);
idx += nt;
}
}
}
#実装例(im2colを見てみる)
PyTorchのPython層では、fold/unfoldメソッドという畳込みの前処理関数がある。この関数は、C++(ATen層)ではim2col
という名前である。im2col
は、画像を畳み込み層の計算を効率的に行うための仕組みである。この演算は2006年ごろから、Matlabでim2col
という名称で使われてきたが、2014年ごろから深層学習フレームワークのCaffeでも使われるようになった。そして、PyTorchでは、Caffeのim2col関数を参考にして実装している。
PyTorchでim2col
は、native_functions.yaml
で定義され、CUDAデバイスの場合im2col_cuda
で処理している。それから、im2col_out_cuda_template
、im2col
、im2col_kernel
の順に呼び出している。
まず、im2col_cuda
から見ていく。
Tensor im2col_cuda(
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef dilation,
IntArrayRef padding,
IntArrayRef stride) {
Tensor output = at::empty_like(input);
im2col_out_cuda_template(
output, input, kernel_size, dilation, padding, stride);
return output;
}
im2col_out_cuda_template
でCUDAのストリームを設定する。
static void im2col_out_cuda_template(
Tensor& output,
const Tensor& input_,
IntArrayRef kernel_size,
IntArrayRef dilation,
IntArrayRef padding,
IntArrayRef stride) {
TORCH_CHECK(
kernel_size.size() == 2,
"It is expected kernel_size equals to 2, but got size ",
kernel_size.size());
TORCH_CHECK(
dilation.size() == 2,
"It is expected dilation equals to 2, but got size ",
dilation.size());
TORCH_CHECK(
padding.size() == 2,
"It is expected padding equals to 2, but got size ",
padding.size());
TORCH_CHECK(
stride.size() == 2,
"It is expected stride equals to 2, but got size ",
stride.size());
int64_t kernel_height = kernel_size[0];
int64_t kernel_width = kernel_size[1];
int64_t dilation_height = dilation[0];
int64_t dilation_width = dilation[1];
int64_t pad_height = padding[0];
int64_t pad_width = padding[1];
int64_t stride_height = stride[0];
int64_t stride_width = stride[1];
TensorArg input_arg{input_, "input", 1};
TensorArg output_arg{output, "output", 2};
checkAllSameGPU("im2col_cuda", {input_arg, output_arg});
im2col_shape_check(
input_,
Tensor(),
kernel_height,
kernel_width,
dilation_height,
dilation_width,
pad_height,
pad_width,
stride_height,
stride_width);
Tensor input = input_.contiguous();
bool batched_input = true;
if (input.dim() == 3) {
batched_input = false;
input.resize_({1, input.size(0), input.size(1), input.size(2)});
}
int64_t batch_size = input.size(0);
int64_t n_input_plane = input.size(1);
int64_t input_height = input.size(2);
int64_t input_width = input.size(3);
int64_t output_height = (input_height + 2 * pad_height -
(dilation_height * (kernel_height - 1) + 1)) /
stride_height +
1;
int64_t output_width = (input_width + 2 * pad_width -
(dilation_width * (kernel_width - 1) + 1)) /
stride_width +
1;
int64_t n_output_plane = n_input_plane * kernel_width * kernel_height;
int64_t output_length = output_height * output_width;
output.resize_({batch_size, n_output_plane, output_length});
output.zero_();
// Launch kernel
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "im2col_out_cuda", [&] {
Tensor input_n;
Tensor output_n;
for (int64_t elt = 0; elt < batch_size; elt++) {
input_n = input.select(0, elt);
output_n = output.select(0, elt);
im2col<scalar_t>(
at::cuda::getCurrentCUDAStream(),
input_n.data_ptr<scalar_t>(),
n_input_plane,
input_height,
input_width,
output_height,
output_width,
kernel_height,
kernel_width,
pad_height,
pad_width,
stride_height,
stride_width,
dilation_height,
dilation_width,
output_n.data_ptr<scalar_t>());
}
if (!batched_input) {
output.resize_({n_output_plane, output_length});
}
});
}
im2col
から、GPUカーネルを起動する。CUDAカーネル定番の定義<<< >>>
が行われている。ここでは、SMのスレッドの最大数1024が設定されている。
template <typename dt>
void im2col(
cudaStream_t stream,
const dt* data_im,
const int64_t channels,
const int64_t height,
const int64_t width,
const int64_t height_col,
const int64_t width_col,
const int64_t kernel_height,
const int64_t kernel_width,
const int64_t pad_height,
const int64_t pad_width,
const int64_t stride_height,
const int64_t stride_width,
const int64_t dilation_height,
const int64_t dilation_width,
dt* data_col) {
// We are going to launch channels * height_col * width_col kernels, each
// kernel responsible for copying a single-channel grid.
int64_t num_kernels = channels * height_col * width_col;
// Launch CUDA_NUM_THREADS = 1024
im2col_kernel<<<GET_BLOCKS(num_kernels), 1024, 0, stream>>>(
num_kernels,
data_im,
height,
width,
kernel_height,
kernel_width,
pad_height,
pad_width,
stride_height,
stride_width,
dilation_height,
dilation_width,
height_col,
width_col,
data_col);
AT_CUDA_CHECK(cudaGetLastError());
}
im2col_kernel
が、GPGPUで動く関数である。
template <typename dt>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void im2col_kernel(
const int64_t n,
const dt* data_im,
const int64_t height,
const int64_t width,
const int64_t kernel_height,
const int64_t kernel_width,
const int64_t pad_height,
const int64_t pad_width,
const int64_t stride_height,
const int64_t stride_width,
const int64_t dilation_height,
const int64_t dilation_width,
const int64_t height_col,
const int64_t width_col,
dt* data_col) {
CUDA_KERNEL_LOOP(index, n) {
int64_t w_out = index % width_col;
index /= width_col;
int64_t h_out = index % height_col;
int64_t channel_in = index / height_col;
int64_t channel_out = channel_in * kernel_height * kernel_width;
int64_t h_in = h_out * stride_height - pad_height;
int64_t w_in = w_out * stride_width - pad_width;
data_col += (channel_out * height_col + h_out) * width_col + w_out;
data_im += (channel_in * height + h_in) * width + w_in;
for (int64_t i = 0; i < kernel_height; ++i) {
for (int64_t j = 0; j < kernel_width; ++j) {
int64_t h = h_in + i * dilation_height;
int64_t w = w_in + j * dilation_width;
*data_col = (h >= 0 && w >= 0 && h < height && w < width)
? data_im[i * dilation_height * width + j * dilation_width]
: ScalarConvert<int, dt>::to(0);
data_col += height_col * width_col;
}
}
}
}
ここで、CUDA_KERNEL_LOOP
は、ループを定義するマクロである。ここで、CUDA定番の<<< >>>
が使われている。
#define CUDA_KERNEL_LOOP(i, n) \
int64_t _i_n_d_e_x = blockIdx.x * blockDim.x + threadIdx.x; \
for (int i=_i_n_d_e_x; _i_n_d_e_x < (n); _i_n_d_e_x+=blockDim.x * gridDim.x, i=_i_n_d_e_x)
実装例(DistributedDataParallelの処理を見てみる)
PyTorchに組み込んである並列処理は、DataParallel
とDistributedDataParallel
がある。出来る処理は以下の通りであり、マルチプロセスの処理はDistributedDataParallel
で行う必要がある。
クラス名 | シングルノードマルチデバイス | マルチノード |
---|---|---|
DataParallel | ○ | × |
DistributedDataParallel | ○ | ○ |
DistributedDataParallelの場合、分散処理の説明文書がある。そして。サンプルコードとしてはexamples/imagenet
がある。
DataParalellの場合、チュートリアルのMULTI-GPU EXAMPLESがある。
そして、Single Node multi GPUでは、以下のオプションで動かすことができる。たとえば、Tesla V100やTesla T4でも動かせる。なお、RTX2080Tiの場合、NVLinkで接続されている必要がある。PCIeバス接続でのマルチGPU接続動かない。
python main.py -a resnet50 --dist-url 'tcp://127.0.0.1:FREEPORT' --dist-backend 'nccl' --multiprocessing-distributed --world-size 1 --rank 0 [imagenet-folder with train and val folders]
PyTorchでの使い方
並列学習では、初期化として、
- モデルの定義の並列化
- データローダーの並列化定義
を行う。
###モデルの定義
モデルの定義は、以下のようにモデルを定義後、並列化クラスを呼び出して行う。
なお、DistributedDataParallel
を使う場合、該当関数を呼び出す前にinit_process_group
が必要である。
import torch.distributed as dist
if args.distributed:
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size)
if not args.distributed:
if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
model.features = torch.nn.DataParallel(model.features)
model.cuda()
else:
model = torch.nn.DataParallel(model).cuda()
else:
model.cuda()
model = torch.nn.parallel.DistributedDataParallel(model)
###データローダの定義
データローダーの定義は以下で行う。ここでtrain_sampler
によりデータ分散が行われている。
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
else:
train_sampler = None
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
num_workers=args.workers, pin_memory=True, sampler=train_sampler)
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
adjust_learning_rate(optimizer, epoch)
# train for one epoch
train(train_loader, model, criterion, optimizer, epoch)
##フレームワークの処理
Python層の処理
並列モデルの定義は、torch.nn.parallel
で行う。そして、その基盤となる関数は、
- コミュニケーション関数群
torch.cuda.comm
- デバイス管理ユーティリティ
torch.cuda._utils
- ストリーム管理クラス
torch.cuda.Stream
である。
DistributedDataParallel
では、以下の順で処理をする。これは、imagenet等のサンプルコードを参照のこと。
torch.distributed.init_process_group
DistributedDataParalell
torch.distributed.init_process_group
は、最終的にProcessGroupXXXX
を呼び出して、NCCL, Gloo等の設定をする。ただし、C++層の話なので後程説明する。
def _new_process_group_helper(world_size,
rank,
group_ranks,
backend,
store,
group_name=None,
timeout=_default_pg_timeout):
"""
Create a new distributed process group.
This function must be called by ALL processes in the global group, even if
the calling process is not part of the newly created group. In that case,
this function returns GroupMember.NON_GROUP_MEMBER.
This function is called with ``group_ranks == []`` for the default group.
"""
global _pg_map
global _group_count
global _pg_names
if not group_name:
group_name = str(_group_count)
_group_count += 1
if group_name in _pg_names.values():
raise RuntimeError("The specified group name has already been "
"created, please use a different group name")
if not isinstance(timeout, timedelta):
raise RuntimeError("Expected timeout argument to be of type"
"datetime.timedelta")
# The list of group ranks is empty if we're creating the default group.
is_default_group = (len(group_ranks) == 0)
backend = Backend(backend)
if backend == Backend.MPI:
if not is_mpi_available():
raise RuntimeError("Distributed package doesn't have MPI built in")
pg = ProcessGroupMPI.create(group_ranks)
if not pg:
return GroupMember.NON_GROUP_MEMBER
_pg_map[pg] = (Backend.MPI, None)
_pg_names[pg] = group_name
else:
# If this is a subgroup (which means group_ranks is specified),
# we check if the current process is a member of the new group.
if not is_default_group:
global_rank = _default_pg.rank()
if global_rank not in group_ranks:
return GroupMember.NON_GROUP_MEMBER
# Use the group name as prefix in the default store, such that
# a single store can be reused by multiple groups.
prefix_store = PrefixStore(group_name, store)
if backend == Backend.GLOO:
pg = ProcessGroupGloo(
prefix_store,
rank,
world_size,
timeout=timeout)
_pg_map[pg] = (Backend.GLOO, store)
_pg_names[pg] = group_name
elif backend == Backend.NCCL:
if not is_nccl_available():
raise RuntimeError("Distributed package doesn't have NCCL "
"built in")
pg = ProcessGroupNCCL(
prefix_store,
rank,
world_size)
_pg_map[pg] = (Backend.NCCL, store)
_pg_names[pg] = group_name
else:
raise RuntimeError("Unsupported distributed backend by group")
return pg
さて、DistributedDataParallel
内では(シングルノードマルチGPU処理では)
-
scatter
でデータをばら撒く - 処理を行う
すでに、DistributedDataParallelが処理を行う時点で、(シングルノードマルチGPU処理では)プロセスごとにデバイスが分けられているので、gather処理は行われない。
def forward(self, *inputs, **kwargs):
if self.require_forward_param_sync:
self._sync_params()
if self.device_ids:
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
output = self.module(*inputs[0], **kwargs[0])
else:
outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
output = self.gather(outputs, self.output_device)
else:
output = self.module(*inputs, **kwargs)
if torch.is_grad_enabled() and self.require_backward_grad_sync:
self.require_forward_param_sync = True
# We'll return the output object verbatim since it is a freeform
# object. We need to find any tensors in this object, though,
# because we need to figure out which parameters were used during
# this forward pass, to ensure we short circuit reduction for any
# unused parameters. Only if `find_unused_parameters` is set.
if self.find_unused_parameters:
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
self.reducer.prepare_for_backward([])
else:
self.require_forward_param_sync = False
return output
scatter
では、デバイスごとにデータのばら撒きと、デバイスからストリームへの変換を行う。
参考までに、_get_stream
で、torch.cuda.Stream
を使って、PythonでデバイスIDからストリームIDへの変換を行っている。
def _get_stream(device):
"""Gets a background stream for copying between CPU and GPU"""
global _streams
if device == -1:
return None
if _streams is None:
_streams = [None] * torch.cuda.device_count()
if _streams[device] is None:
_streams[device] = torch.cuda.Stream(device)
return _streams[device]
C++層での処理
Python層の処理を受けて、CUDAライブラリの呼び出しが行われる(CUDA Runtime & NCCL)。
ここでは、ストリーム(Stream)の初期化及びその管理について見ていく。
####ストリームの初期化
PyTorchでは、torch.cuda.Stream
をPython層で定義する。そこから、C++層に潜りtorch._C._CudaStreamBase
を呼び出す。なお、getStreamFromPool
のinitDeviceStreamState
延長でcudaStream
の初期化が行われる。このストリームは、低プライオリティ(priority=0
)で使われる。
####複数デバイス間通信処理
Python層(torch.cuda.comm
)から呼び出されるProcessGroupNCCL
の初期化を行う。一連の作業で、 ncclCommWatchdogThread_
スレッドを起動する。また、WorkNCCLという名前空間があるが、それもinit.cpp
で定義される。
また、CUDAのメソッド(scatter/gather等)は、起動時のPython-Cの接続で定義される。モジュールtorch._C
は、以下で定義する。
#参考資料
##ソースコード
ATen/native
add関連
- aten/src/ATen/native/cuda/BinaryOpsKernel.cu
- aten/src/ATen/native/TensorIterator.h
-
aten/src/ATen/Dispatch.h
-
AT_DISPATCH_ALL_TYPES_AND
複数のデータ型に対応したラムダ式を作成するマクロ
-
- aten/src/ATen/native/cuda/Loops.cuh
im2col関連
- aten/src/ATen/native/cuda/Im2Col.cu
- aten/src/ATen/native/cuda/im2col.cuh
- aten/src/ATen/cuda/detail/KernelUtils.h
#### その他
THC
- レガシー(Legacy)なので以下の関数は使ってはいけない。参考資料としてみるだけ。
- aten/src/THC/THCApply.cuh
- aten/src/THC/THCReduce.cuh
ストリームレベル
Python
C++
C言語仕様
CUDA
入門
C拡張仕様
-
CUDA C Programming Guide
-
B. C Language Extensions
-
B.1. Function Execution Space Specifiers
-
__device__
,__host__
,__global__
-
-
B.3. Built-in Vector Types
-
uint3
,int2
-
-
B.4. Built-in Variables
-
blockIdx
,threadIdx
,blockDim
,gridDim
,warpSize
-
-
B.22. Execution Configuration
<<< >>>
-
B.1. Function Execution Space Specifiers
- H.1. Features and Technical Specifications
- J. CUDA Environment Variables
-
B. C Language Extensions
Runtime API
libcudart.soがCUDA Runtimeライブラリである。
-
NVIDIA CUDA Runtime API
-
1. Difference between the driver and runtime APIs
- runtimeライブラリとdriverライブラリの違い。PyTorchは、runtimeライブラリを使っている。参考までに、TensorFlowでは、driverライブラリを使っている。
- 5.9. Memory Management cudaMemcpy等の関数説明
-
1. Difference between the driver and runtime APIs
####NCCL API
####ハードウェアの資料
- NVIDIA Volta Arcihtecture Whitepaper
- NVIDIA Pascal Architecture Whitepaper
- NVIDIA Maxwell Architecture
- NVIDIA Kepler Architectue Whitepaper
- NVIDIA Fermi Architecture Whitepaper
C++の一般的な仕様
- Lambda expressions (since C++11) (ラムダ式 関数の一種)
- ブログ等
##数値演算