こんな記事が出てきた。
2のべき乗サイズの配列は危ないという話 via 行列積 - elkurin’s blog
行列を連続した2次元配列で表現し、列のサイズが2のべき乗になっていると、
CPUキャッシュのタグが衝突しやすく、2のべき乗±1のときより時間がかかってしまう、という主張である。
本当にそうなるのか、試してみた。
検証環境
今回は、CS50 IDE上で実験を行った。
lscpu
コマンドによるCPUの情報は、以下のようになっていた。 (抜粋)
Model name: Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
CPU MHz: 3599.968
L1d cache: 256 KiB
L1i cache: 256 KiB
L2 cache: 8 MiB
L3 cache: 35.8 MiB
また、計算時間は10回測定して平均を取った。
グラフ中の誤差範囲は99%信頼区間を示している。
愚直な実装
まずは何も考えずに行列の積を計算するコードを書くとやりがちな、素直な実装である。
for (i = 0; i < N; i++) {
for (j = 0; j < N; j++) {
c[i][j] = 0;
for (k = 0; k < N; k++) {
c[i][j] += a[i][k] * b[k][j];
}
}
}
この実装では、配列 b
の要素を縦方向に1個ずつアクセスするため、キャッシュが効率よく使えずに遅くなる、
というのは有名な話である。[要出典]
N=1024付近
N=1024付近で計算時間を測定した結果は、以下のようになった。
N | 計算時間 [s] |
---|---|
1022 | 2.578 |
1023 | 2.856 |
1024 | 6.646 |
1025 | 2.779 |
1026 | 2.563 |
記事の通り、N=1024の時はそのまわりに比べて計算時間が飛び抜けて長くなった。
N=2048付近
N=2048付近で計算時間を測定してみると、以下のようになった。
N | 計算時間 [s] |
---|---|
2046 | 108.464 |
2047 | 108.597 |
2048 | 111.842 |
2049 | 108.366 |
2050 | 108.415 |
N=1024の時のように数倍にはならなかったが、N=2048のときもまわりに比べて計算時間が長くなった。
ちなみに、グラフの端の部分を拡大すると、以下のようになった。
また、この行列積は$O(N^3)$なので、理論上はNの値が2倍になると計算時間は8倍程度になるはずである。
それに対し、N=2050とN=1025の比較では約39.01倍、N=2048とN=1024の比較では約16.83倍と、大幅な悪化がみられた。
N=1024付近をもう少しみてみる
N=1024付近について、実験を行う範囲を拡大してみた。
結果は、以下のようになった。
N | 計算時間 [s] | N | 計算時間 [s] | N | 計算時間 [s] |
---|---|---|---|---|---|
1014 | 2.563 | 1021 | 2.825 | 1028 | 2.435 |
1015 | 2.666 | 1022 | 2.609 | 1029 | 2.838 |
1016 | 2.227 | 1023 | 2.730 | 1030 | 2.654 |
1017 | 2.791 | 1024 | 6.695 | 1031 | 2.830 |
1018 | 2.617 | 1025 | 2.735 | 1032 | 2.275 |
1019 | 2.728 | 1026 | 2.644 | 1033 | 2.877 |
1020 | 2.320 | 1027 | 2.809 | 1034 | 2.638 |
やはりN=1024のときの計算時間が飛び抜けて長くなった。
一方、この範囲において、Nが1024以外の4の倍数のときはなぜか逆に計算時間が若干短くなった。
少し改良した実装
行列積の計算 c[i][j] += a[i][k] * b[k][j];
を外側から i,j,k
のループで回してしまうと、
配列 b
のアクセスでキャッシュを効率よく使えず、計算が遅くなってしまう。
一方、外側から i,k,j
のループにすれば、どの配列も横方向にアクセスすることになり、計算時間の短縮に繋がる。
for (i = 0; i < N; i++) {
for (j = 0; j < N; j++) {
c[i][j] = 0;
}
for (k = 0; k < N; k++) {
for (j = 0; j < N; j++) {
c[i][j] += a[i][k] * b[k][j];
}
}
}
このコードについても、同様に計算時間を測定した。
N=1024付近
N=1024付近で計算時間を測定した結果は、以下のようになった。
N | 計算時間 [s] |
---|---|
1022 | 1.360 |
1023 | 1.308 |
1024 | 1.318 |
1025 | 1.354 |
1026 | 1.295 |
愚直な実装と比べて計算時間が約半分になり、さらにN=1024においても計算時間がまわりとあまり変わらなくなった。
N=2048付近
N=2048付近で計算時間を測定してみると、以下のようになった。
N | 計算時間 [s] |
---|---|
2046 | 10.415 |
2047 | 10.504 |
2048 | 10.423 |
2049 | 10.289 |
2050 | 10.442 |
愚直な実装と比べて計算時間が約10分の1と大幅に改善し、N=2048とまわりの計算時間の差も少なくなった。
N=2050とN=1025の計算時間は約7.71倍、N=2048とN=1024の計算時間は約7.91倍と、ほぼ理論通りの比となった。
N=1024付近をもう少しみてみる
N=1024付近で範囲を広げて計算時間を測定してみると、以下のようになった。
N | 計算時間 [s] | N | 計算時間 [s] | N | 計算時間 [s] |
---|---|---|---|---|---|
1014 | 1.410 | 1021 | 1.326 | 1028 | 1.386 |
1015 | 1.327 | 1022 | 1.330 | 1029 | 1.360 |
1016 | 1.317 | 1023 | 1.333 | 1030 | 1.469 |
1017 | 1.343 | 1024 | 1.367 | 1031 | 1.381 |
1018 | 1.389 | 1025 | 1.349 | 1032 | 1.428 |
1019 | 1.328 | 1026 | 1.381 | 1033 | 1.387 |
1020 | 1.345 | 1027 | 1.354 | 1034 | 1.405 |
N=1024においてもその他の4の倍数においてもあまり差はみられず、
Nが大きくなるにつれて少しずつ計算時間が長くなる様子がみられた。
結論
キャッシュの効率が悪い行列積の計算方法においては、
配列のサイズが2の累乗のとき計算時間が伸びる現象がみられた。
一方、キャッシュを効率よく使える行列積の計算方法では、
配列のサイズが2の累乗かどうかによる計算時間への影響はみられなかった。
計算の効率を上げたい場合は、まずはキャッシュを効率よく使えるようにすることを考え、
配列のサイズの調整は、それが難しい時の次の手とするのがよいだろう。
付録
検証に使用したプログラム
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <sys/time.h>
long long get_time(void) {
struct timeval tv;
if (gettimeofday(&tv, NULL) == 0) {
return tv.tv_sec * 1000000 + tv.tv_usec;
} else {
return 0;
}
}
#ifndef N
#define N 1024
#endif
#define HASH_MULT 59879u
unsigned int a[N][N], b[N][N], c[N][N];
int main(void) {
int i, j, k;
unsigned int hash;
long long start_time, end_time;
srand(0);
for (i = 0; i < N; i++) {
for (j = 0; j < N; j++) {
a[i][j] = rand();
b[i][j] = rand();
}
}
start_time = get_time();
#ifdef VERTICAL
for (i = 0; i < N; i++) {
for (j = 0; j < N; j++) {
c[i][j] = 0;
for (k = 0; k < N; k++) {
c[i][j] += a[i][k] * b[k][j];
}
}
}
#else
for (i = 0; i < N; i++) {
for (j = 0; j < N; j++) {
c[i][j] = 0;
}
for (k = 0; k < N; k++) {
for (j = 0; j < N; j++) {
c[i][j] += a[i][k] * b[k][j];
}
}
}
#endif
end_time = get_time();
hash = 0;
for (i = 0; i < N; i++) {
for (j = 0; j < N; j++) {
hash = hash * HASH_MULT + c[i][j];
}
}
#ifdef VERTICAL
puts("vertical access mode");
#else
puts("normal mode");
#endif
printf("N = %d\n", N);
printf("hash = %u\n", hash);
printf("time = %d.%06d s\n", (int)(end_time - start_time) / 1000000, (int)(end_time - start_time) % 1000000);
return 0;
}
#!/usr/bin/perl
use strict;
use warnings;
my $run_num = 10;
my @n_list = (1022, 1023, 1024, 1025, 1026, 2046, 2047, 2048, 2049, 2050);
my @programs = ();
print STDERR "compiling programs...\n";
for (my $i = 0; $i < @n_list; $i++) {
my $n = $n_list[$i];
my $program = "./matmul-test-$n";
system("gcc -O2 -o $program matmul-test.c -DN=$n");
system("gcc -O2 -o $program-v matmul-test.c -DN=$n -DVERTICAL");
push(@programs, $program);
push(@programs, $program . "-v");
}
print STDERR "running programs...\n";
for (my $i = 0; $i < @programs; $i++) {
my $program = $programs[$i];
print STDERR "$program : ";
print $program;
for (my $j = 0; $j < $run_num; $j++) {
open(PROG, "$program |") or die("exec error\n");
unless (index(<PROG>, " mode") >= 0) { die("unexpected output\n"); }
unless (index(<PROG>, "N =") >= 0) { die("unexpected output\n"); }
unless (index(<PROG>, "hash =") >= 0) { die("unexpected output\n"); }
unless (<PROG> =~ /time = (\d+\.\d+) s/) { die("unexpected output\n"); }
print ",$1";
close(PROG);
print STDERR "*";
}
print "\n";
print STDERR "\n";
}
print STDERR "done.\n";
matmul-test.c
と matmul-test-run.pl
を同じディレクトリに置き、
matmul-test-run.pl
を実行すると、 matmul-test.c
のコンパイルと実行が行われ、
計算時間の測定結果(CSV)が標準出力に、測定の進行状況が標準エラー出力に書き出される。
範囲を広げての測定は、 matmul-test-run.pl
のうち
my @n_list = (1022, 1023, 1024, 1025, 1026, 2046, 2047, 2048, 2049, 2050);
の部分を
my @n_list = (
1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023,
1024,
1025, 1026, 1027, 1028, 1029, 1030, 1031, 1032, 1033, 1034
);
に差し替えて行った。