LoginSignup
9
6

More than 3 years have passed since last update.

2の累乗サイズの行列の積は、周辺のサイズより計算に時間がかかる?

Posted at

こんな記事が出てきた。

2のべき乗サイズの配列は危ないという話 via 行列積 - elkurin’s blog

行列を連続した2次元配列で表現し、列のサイズが2のべき乗になっていると、
CPUキャッシュのタグが衝突しやすく、2のべき乗±1のときより時間がかかってしまう、という主張である。

本当にそうなるのか、試してみた。

検証環境

今回は、CS50 IDE上で実験を行った。
lscpuコマンドによるCPUの情報は、以下のようになっていた。 (抜粋)

Model name:                      Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
CPU MHz:                         3599.968
L1d cache:                       256 KiB
L1i cache:                       256 KiB
L2 cache:                        8 MiB
L3 cache:                        35.8 MiB

また、計算時間は10回測定して平均を取った。
グラフ中の誤差範囲は99%信頼区間を示している。

愚直な実装

まずは何も考えずに行列の積を計算するコードを書くとやりがちな、素直な実装である。

for (i = 0; i < N; i++) {
    for (j = 0; j < N; j++) {
        c[i][j] = 0;
        for (k = 0; k < N; k++) {
            c[i][j] += a[i][k] * b[k][j];
        }
    }
}

この実装では、配列 b の要素を縦方向に1個ずつアクセスするため、キャッシュが効率よく使えずに遅くなる、
というのは有名な話である。[要出典]

N=1024付近

N=1024付近で計算時間を測定した結果は、以下のようになった。

N 計算時間 [s]
1022 2.578
1023 2.856
1024 6.646
1025 2.779
1026 2.563

i,j,k のループ N=1024付近

記事の通り、N=1024の時はそのまわりに比べて計算時間が飛び抜けて長くなった。

N=2048付近

N=2048付近で計算時間を測定してみると、以下のようになった。

N 計算時間 [s]
2046 108.464
2047 108.597
2048 111.842
2049 108.366
2050 108.415

i,j,k のループ N=2048付近

N=1024の時のように数倍にはならなかったが、N=2048のときもまわりに比べて計算時間が長くなった。
ちなみに、グラフの端の部分を拡大すると、以下のようになった。

i,j,k のループ N=2048付近 (グラフの端を拡大)

また、この行列積は$O(N^3)$なので、理論上はNの値が2倍になると計算時間は8倍程度になるはずである。
それに対し、N=2050とN=1025の比較では約39.01倍、N=2048とN=1024の比較では約16.83倍と、大幅な悪化がみられた。

N=1024付近をもう少しみてみる

N=1024付近について、実験を行う範囲を拡大してみた。
結果は、以下のようになった。

N 計算時間 [s] N 計算時間 [s] N 計算時間 [s]
1014 2.563 1021 2.825 1028 2.435
1015 2.666 1022 2.609 1029 2.838
1016 2.227 1023 2.730 1030 2.654
1017 2.791 1024 6.695 1031 2.830
1018 2.617 1025 2.735 1032 2.275
1019 2.728 1026 2.644 1033 2.877
1020 2.320 1027 2.809 1034 2.638

i,j,k のループ N=1024付近 (広範囲)

やはりN=1024のときの計算時間が飛び抜けて長くなった。
一方、この範囲において、Nが1024以外の4の倍数のときはなぜか逆に計算時間が若干短くなった。

少し改良した実装

行列積の計算 c[i][j] += a[i][k] * b[k][j]; を外側から i,j,k のループで回してしまうと、
配列 b のアクセスでキャッシュを効率よく使えず、計算が遅くなってしまう。
一方、外側から i,k,j のループにすれば、どの配列も横方向にアクセスすることになり、計算時間の短縮に繋がる。

for (i = 0; i < N; i++) {
    for (j = 0; j < N; j++) {
        c[i][j] = 0;
    }
    for (k = 0; k < N; k++) {
        for (j = 0; j < N; j++) {
            c[i][j] += a[i][k] * b[k][j];
        }
    }
}

このコードについても、同様に計算時間を測定した。

N=1024付近

N=1024付近で計算時間を測定した結果は、以下のようになった。

N 計算時間 [s]
1022 1.360
1023 1.308
1024 1.318
1025 1.354
1026 1.295

i,k,j のループ N=1024付近

愚直な実装と比べて計算時間が約半分になり、さらにN=1024においても計算時間がまわりとあまり変わらなくなった。

N=2048付近

N=2048付近で計算時間を測定してみると、以下のようになった。

N 計算時間 [s]
2046 10.415
2047 10.504
2048 10.423
2049 10.289
2050 10.442

i,j,k のループ N=2048付近

愚直な実装と比べて計算時間が約10分の1と大幅に改善し、N=2048とまわりの計算時間の差も少なくなった。
N=2050とN=1025の計算時間は約7.71倍、N=2048とN=1024の計算時間は約7.91倍と、ほぼ理論通りの比となった。

N=1024付近をもう少しみてみる

N=1024付近で範囲を広げて計算時間を測定してみると、以下のようになった。

N 計算時間 [s] N 計算時間 [s] N 計算時間 [s]
1014 1.410 1021 1.326 1028 1.386
1015 1.327 1022 1.330 1029 1.360
1016 1.317 1023 1.333 1030 1.469
1017 1.343 1024 1.367 1031 1.381
1018 1.389 1025 1.349 1032 1.428
1019 1.328 1026 1.381 1033 1.387
1020 1.345 1027 1.354 1034 1.405

i,k,j のループ N=1024付近 (広範囲)

N=1024においてもその他の4の倍数においてもあまり差はみられず、
Nが大きくなるにつれて少しずつ計算時間が長くなる様子がみられた。

結論

キャッシュの効率が悪い行列積の計算方法においては、
配列のサイズが2の累乗のとき計算時間が伸びる現象がみられた。
一方、キャッシュを効率よく使える行列積の計算方法では、
配列のサイズが2の累乗かどうかによる計算時間への影響はみられなかった。

計算の効率を上げたい場合は、まずはキャッシュを効率よく使えるようにすることを考え、
配列のサイズの調整は、それが難しい時の次の手とするのがよいだろう。

付録

検証に使用したプログラム

matmul-test.c
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <sys/time.h>

long long get_time(void) {
    struct timeval tv;
    if (gettimeofday(&tv, NULL) == 0) {
        return tv.tv_sec * 1000000 + tv.tv_usec;
    } else {
        return 0;
    }
}

#ifndef N
#define N 1024
#endif

#define HASH_MULT 59879u

unsigned int a[N][N], b[N][N], c[N][N];

int main(void) {
    int i, j, k;
    unsigned int hash;
    long long start_time, end_time;
    srand(0);
    for (i = 0; i < N; i++) {
        for (j = 0; j < N; j++) {
            a[i][j] = rand();
            b[i][j] = rand();
        }
    }

    start_time = get_time();
#ifdef VERTICAL
    for (i = 0; i < N; i++) {
        for (j = 0; j < N; j++) {
            c[i][j] = 0;
            for (k = 0; k < N; k++) {
                c[i][j] += a[i][k] * b[k][j];
            }
        }
    }
#else
    for (i = 0; i < N; i++) {
        for (j = 0; j < N; j++) {
            c[i][j] = 0;
        }
        for (k = 0; k < N; k++) {
            for (j = 0; j < N; j++) {
                c[i][j] += a[i][k] * b[k][j];
            }
        }
    }
#endif
    end_time = get_time();

    hash = 0;
    for (i = 0; i < N; i++) {
        for (j = 0; j < N; j++) {
            hash = hash * HASH_MULT + c[i][j];
        }
    }
#ifdef VERTICAL
    puts("vertical access mode");
#else
    puts("normal mode");
#endif
    printf("N = %d\n", N);
    printf("hash = %u\n", hash);
    printf("time = %d.%06d s\n", (int)(end_time - start_time) / 1000000, (int)(end_time - start_time) % 1000000);
    return 0;
}
matmul-test-run.pl
#!/usr/bin/perl

use strict;
use warnings;

my $run_num = 10;
my @n_list = (1022, 1023, 1024, 1025, 1026, 2046, 2047, 2048, 2049, 2050);

my @programs = ();

print STDERR "compiling programs...\n";

for (my $i = 0; $i < @n_list; $i++) {
    my $n = $n_list[$i];
    my $program = "./matmul-test-$n";
    system("gcc -O2 -o $program matmul-test.c -DN=$n");
    system("gcc -O2 -o $program-v matmul-test.c -DN=$n -DVERTICAL");
    push(@programs, $program);
    push(@programs, $program . "-v");
}

print STDERR "running programs...\n";

for (my $i = 0; $i < @programs; $i++) {
    my $program = $programs[$i];
    print STDERR "$program : ";
    print $program;
    for (my $j = 0; $j < $run_num; $j++) {
        open(PROG, "$program |") or die("exec error\n");
        unless (index(<PROG>, " mode") >= 0) { die("unexpected output\n"); }
        unless (index(<PROG>, "N =") >= 0) { die("unexpected output\n"); }
        unless (index(<PROG>, "hash =") >= 0) { die("unexpected output\n"); }
        unless (<PROG> =~ /time = (\d+\.\d+) s/) { die("unexpected output\n"); }
        print ",$1";
        close(PROG);
        print STDERR "*";
    }
    print "\n";
    print STDERR "\n";
}

print STDERR "done.\n";

matmul-test.cmatmul-test-run.pl を同じディレクトリに置き、
matmul-test-run.pl を実行すると、 matmul-test.c のコンパイルと実行が行われ、
計算時間の測定結果(CSV)が標準出力に、測定の進行状況が標準エラー出力に書き出される。

範囲を広げての測定は、 matmul-test-run.pl のうち

my @n_list = (1022, 1023, 1024, 1025, 1026, 2046, 2047, 2048, 2049, 2050);

の部分を

my @n_list = (
    1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023,
    1024,
    1025, 1026, 1027, 1028, 1029, 1030, 1031, 1032, 1033, 1034
);

に差し替えて行った。

9
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
6