LoginSignup
0
0

More than 3 years have passed since last update.

ソースコードに探索ロジックをハードコードしたら速くなるのか

Last updated at Posted at 2020-05-16

導入

固定値のデータが必要な場合、外部ファイルを使用せずにソースコードに直接埋め込んでしまうことがあります。
データが大量にある場合はビルド前スクリプトで生成することもあります。

constexpr int table[] = {
    2,
    3,
    5,
    7,
    11,
    13,
    17,
    19,
    ...
    9973
}
bool is_primitive_table_slow(int n) {
    return find(begin(table), end(table), n) != end(table);
}

上のコードは次のように書くこともできます。

bool is_primitive_direct_slow(int n) {
    if (n == 2) return true;
    if (n == 3) return true;
    if (n == 5) return true;
    if (n == 7) return true;
    if (n == 11) return true;
    if (n == 13) return true;
    if (n == 17) return true;
    if (n == 19) return true;
    ....
    if (n == 9973) return true;
    return false;
}

通常ならばアンチパターンですが、ソースコードを自動生成する場合は禁を破っても良いのではないかと考えました。
下の直接書いたものの方が速そうですが実際どうなのか、という調査です。

ソースコード

上に加えて、

bool is_primitive_table_binary(int n) {
    return binary_search(begin(table), end(table), n);
}
bool is_primitive_direct_binary(int n) {
    if (4523 <= n) {
        if (7213 <= n) {
            ...
        } else {
            ...
        }
    } else {
        ...
                                    if (5 <= n) {
                                        if (7 <= n) {
                                            return 7 == n;
                                        } else {
                                            return 5 == n;
                                        }
                                    } else {
                                        if (3 <= n) {
                                            return 3 == n;
                                        } else {
                                            return 2 == n;
                                        }
                                    }
        ...
    }
}

も試します。

is_primitive_direct_binary の生成プログラム:

void rec(ostream &ofs, const int *begin, const int *end, int depth);

int main() {
    cout << "bool is_primitive_direct_binary(int n) {" << endl;
    rec(cout, begin(table), end(table), 1); // table は上のものと同じ
    cout << "}" << endl;
}

template<typename T>
T &with_indent(T &ofs, int n) {
    for (int i = 0; i < n; ++i) ofs << "\t";
    return ofs;
}

void rec(ostream &ofs, const int *begin, const int *end, int depth) {
    if (end - begin == 1) {
        with_indent(ofs, depth) << "return " << *begin << " == n;" << endl;
        return;
    }
    const int *mid = begin + (end - begin) / 2;
    with_indent(ofs, depth) << "if (" << *mid << " <= n) {" << endl;
    rec(ofs, mid, end, depth + 1);
    with_indent(ofs, depth) << "} else {" << endl;
    rec(ofs, begin, mid, depth + 1);
    with_indent(ofs, depth) << "}" << endl;
}

追加

bool is_primitive_direct_switch(int n) {
    switch(n) {
    case 2:
    case 3:
    case 5:
    case 7:
    case 11:
    ...
    case 9973:
        return true;
    default:
        return false;
    }
}

計測コード:

template<typename F>
void measure(const char *tag, F f) {
    auto t0 = system_clock::now();
    auto result = f();
    auto t1 = system_clock::now();
    cout << tag << ": " << setw(6) << duration_cast<milliseconds>(t1 - t0).count() << "ms" << ", result: " << result << endl;
}

int main() {
    for (int i = 0; i < 5; ++i) {
        vector<int> args;
        for (int i = 0; i < 1'000'000; ++i)
            args.push_back(rand() % 10000 + 1);

        measure(
            "table slow   ",
            [&args]() {
                int count = 0;
                for (int n : args)
                    if (is_primitive_table_slow(n))
                        ++count;
                return count;
            }
        );

        measure(
            "direct slow  ",
            [&args]() {
                int count = 0;
                for (int n : args)
                    if (is_primitive_direct_slow(n))
                        ++count;
                return count;
            }
        );

        measure(
            "table binary ",
            [&args]() {
                int count = 0;
                for (int n : args)
                    if (is_primitive_table_binary(n))
                        ++count;
                return count;
            }
        );

        measure(
            "direct binary",
            [&args]() {
                int count = 0;
                for (int n : args)
                    if (is_primitive_direct_binary(n))
                        ++count;
                return count;
            }
        );

        cout << endl;
    }
}

Visual Studio 2019
Release モード (/O2)

結果

table slow direct slow table binary direct binary direct switch
366ms 192ms 46ms 25ms 12ms
367ms 184ms 46ms 25ms 12ms
348ms 201ms 45ms 25ms 12ms
396ms 185ms 46ms 25ms 12ms
304ms 184ms 45ms 25ms 12ms

1.8倍程度速くなった。

追加
switch 文バージョンが爆速という結果になりました。
C++コンパイラはあなたよりも良いコードを書く

switch 文バージョンのアセンブラコード (追加)

(アセンブリは初めてなので間違いあったら指摘してください)

長いので省略
; n = ecx, return = al
?is_primitive_direct_switch@@YA_NH@Z PROC
    cmp ecx, 257
    jg  SHORT $LN7@is_primiti ; n < 257 => goto LN7
    je  $LN4@is_primiti       ; n == 257 => return TRUE

    add ecx, -2               ; n -= 2

    cmp ecx, 249
    ja  $LN5@is_primiti ; n < 249 => return FALSE  (2+249=251)

    ; テーブルを参照し、0ならば TRUE, 1ならば FALSE を返す
    movzx eax, BYTE PTR $LN44@is_primiti[ecx]
    jmp   DWORD PTR $LN45@is_primiti[eax*4]

$LN7@is_primiti:
    cmp ecx, 521
    jg  SHORT $LN8@is_primiti ; n < 521 => goto LN8
    je  $LN4@is_primiti       ; n == 521 => return TRUE

    add ecx, -263             ; n-=263

    cmp ecx, 246
    ja  $LN5@is_primiti       ; n < 246 => return FALSE  (263+246=509)

    ; テーブルを参照し、0ならば TRUE, 1ならば FALSE を返す
    movzx eax, BYTE PTR $LN46@is_primiti[ecx]
    jmp   DWORD PTR $LN47@is_primiti[eax*4]

$LN8@is_primiti:
    cmp ecx, 787
    jg  SHORT $LN9@is_primiti ; n < 787 => goto LN9
    je  $LN4@is_primiti       ; n == 787 => return TRUE

    add ecx, -523             ; n -= 523

    cmp ecx, 250
    ja  $LN5@is_primiti       ; n < 250 => return FALSE  (523+250=773)

    ; テーブルを参照し、0ならば TRUE, 1ならば FALSE を返す
    movzx eax, BYTE PTR $LN48@is_primiti[ecx]
    jmp   DWORD PTR $LN49@is_primiti[eax*4]

$LN9@is_primiti:
...

$LN4@is_primiti: ; TRUE を返す
    mov al, 1
    ret 0
$LN5@is_primiti: ; FALSE を返す
    xor al, al
    ret 0
    npad 2
$LN45@is_primiti:
    DD $LN4@is_primiti
    DD $LN5@is_primiti
$LN44@is_primiti:
    DB 0 ; n==2
    DB 0 ; n==3
    DB 1 ; n==4
    DB 0 ; n==5
    DB 1 ; n==6

    ...  ; 全部で 250 行

    DB 1 ; n==250
    DB 0 ; n==251
    npad 2
$LN47@is_primiti:
    DD $LN4@is_primiti
    DD $LN5@is_primiti
$LN46@is_primiti:
    DB 0 ; n==263
    DB 1
    DB 1
    DB 1
    DB 1

    ...  ; 全部で 252 行

    DB 1
    DB 0
    npad 1
$LN49@is_primiti:
    DD $LN4@is_primiti
    DD $LN5@is_primiti
$LN48@is_primiti:
    DB 0 ; n==523
    DB 1
    DB 1
    DB 1
    DB 1

    ...  ; 全部で 253 行

    DB 1
    DB 0
    npad 1
?is_primitive_direct_switch@@YA_NH@Z ENDP


つまり、256未満毎に分割されたテーブルを小さいものから順にO(1)で参照しています。
テーブル境界については直接返しています。

nが整数だからこそ取れる技であり、また、trueになる割合が小さいとバイトコードが肥大化します。
その時はコンパイラも別の手段をとるでしょう。

結論

劇的に、というわけにはいきませんでしたが、2倍近い速度差が出たのでやってみる価値はあると思いました。
二分探索バージョンのように、複雑なアルゴリズムを使用した場合は生成されるソースコードも複雑なものになるので、このコードが独り歩きしないようにコード生成環境を整備するのは必須です。

0
0
11

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0