Tensor Coresは、複数の単純な計算を同時に実行できるユニットで、特にAIで必要とされる密行列の積を実行する能力が高いと言われています。そこで、512×1024の行列と1024×2048の行列の掛け算にかかる時間を計測して確認してみます。
以下のレポジトリに動作するプログラムは置いてあります。記事中のプログラムはここからの抜粋です。
https://github.com/Satachito/matpro
前提
行列のデータは Row-major order で格納しています。
環境
- CPU
- Intel(R) Xeon(R) CPU @ 2.80GHz
- GPU
- A100
- GEFORCE RTX 3090
- T4
準備
#define M 512
#define K 1024
#define N 2048
C++
参考のために C++ でのプログラムを作成し、計測します。
void
MatPro( float* a, float* b, float* c ) {
for ( auto m = 0; m < M; m++ ) {
for ( auto n = 0; n < N; n++ ) {
auto $ = 0.;
for ( auto k = 0; k < K; k++ ) $ += a[ m * K + k ] * b[ k * N + n ];
c[ m * N + n ] = $;
}
}
}
-Ofast オプションをつけてコンパイルしたとき、実行時間は 317,639 μs でした。
CUDA Cores を使う
__global__ void
MatPro( float* a, float* b, float* c ) {
auto $ = 0.;
auto n = blockIdx.x * blockDim.x + threadIdx.x;
auto m = blockIdx.y * blockDim.y + threadIdx.y;
for ( auto k = 0; k < K; k++ ) $ += a[ m * K + k ] * b[ k * K + n ];
c[ m * N + n ] = $;
}
MatPro<<< dim3( N / 32, M / 32 ), dim3( 32, 32 ) >>>( a, b, c );
実行時間:
A100 | RTX 3090 | T4 |
---|---|---|
1,222μs | 9,329μs | 52,495μs |
Tensor Cores を半精度(16bit)データで使う
template < typename F > __global__ void
MatPro(
const half* _a
, const half* _b
, F* _c
) {
wmma::fragment< wmma::matrix_a, 16, 16, 16, half, wmma::row_major > a;
wmma::fragment< wmma::matrix_b, 16, 16, 16, half, wmma::row_major > b;
wmma::fragment< wmma::accumulator, 16, 16, 16, F > c;
wmma::fill_fragment( c, 0 );
for ( auto k = 0; k < K; k += 16 ) {
wmma::load_matrix_sync( a, _a + ( blockIdx.y * K * 16 + k ), K );
wmma::load_matrix_sync( b, _b + ( k * N + blockIdx.x * 16 ), N );
wmma::mma_sync( c, a, b, c );
}
wmma::store_matrix_sync( _c + ( blockIdx.y * N * 16 + blockIdx.x * 16 ), c, N, wmma::mem_row_major );
}
MatPro< F ><<< dim3( N / 16, M / 16 ), 32 >>>( a, b, c );
実行時間(結果が単精度:32bit):
A100 | RTX 3090 | T4 |
---|---|---|
547μs | 472μs | 2035μs |
実行時間(結果が半精度:16bit):
A100 | RTX 3090 | T4 |
---|---|---|
435μs | 319μs | 1,815μs |
考察
Tensor Coresは、行列積の計算において、半精度データを使用することで、メモリの節約と計算速度の向上が可能であることを示しています。
https://github.com/Satachito/matpro
このレポジトリに密行列の積を計算するマルチCPUやSIMD(AVX-512)を使った場合などのさまざまなプログラムを入れてあります。
英文