3
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

AIとpopcountと私(年寄エンジニア)

Last updated at Posted at 2025-01-12

背景(ポエム) 

この記事は、あるエンジニアが「数を2進数表記したときの1の数え方にはたくさんの方法があって、とても素晴らしいアルゴリズムがある」という話を聞いたことから始まる。
紹介した人は当然、言語によっても性能が異なるということを知っていたうえで、アルゴリズムの重要性や面白さの説明のために情報を紹介してくれた。

どんなにハードウェアが進化してもアルゴリズムというのは超重要で、それは処理速度がオーダレベルで変わるからだ。

しかし、同時に昔最速だった実装が今日でも早いかどうかは不明だ。
となると、環境ごとに試す必要もあるし、実装して試すのが速いがめんどくさい。

めんどくさいが、今や我々にはAIがあるんだからグダグダ言わずにAIにやらせればいいじゃないか。

結論First

想像した順ではなかった。
2025年1月時点では、AIが生成するコードは、計測方法や、LOOKUPテーブルの作り方など「ん?」と感じる部分があるが、多くの場面では許容出来ると思う。

  • VSCode & Cline & Claude Sonnetでほぼ自動。エンジニアあと何年仕事あるかな。
  • 専用の関数はやはり最速
  • 高級言語は作らなければ作らないほど早い(組み込み関数をうまく使え)

Pythonで8つのアルゴリズムの性能比較結果(10,000回の繰り返し)

No. アルゴリズム 処理概要 処理時間 (ms) 最速アルゴリズムとの比率
1 Native Pythonのint.bit_count()メソッド 0.242 1.00x
2 Lookup16 16ビットごとのルックアップテーブル使用 0.778 3.21x
3 Lookup8 8ビットごとのルックアップテーブル使用 1.127 4.65x
4 String ビットシフト演算を用いて、各ビットが1を確認 1.441 5.95x
5 Builtin Pythonの組み込み関数bin()を使用して2進数文字列に変換し、'1'をcount 1.792 7.40x
6 Parallel 並列ビットカウントアルゴリズム 4.117 17.00x
7 Kernighan Brian Kernighanのアルゴリズム 5.886 24.30x
8 Shift ビットシフト演算を用いて、各ビットが1であるかどうかを順に確認 6.541 27.00x

結論Second(おまけです)

一周まわって、じゃぁ、C言語では?という興味がわくので、それもAIに「Pythonのコードと同じ内容をC言語で実装して」と依頼。
出来上がったのが以下。

まとめると

  • コンパイル方法、実行方法まで教えてくれる超親切
  • ちょっと知っている用語で無茶振りをしたらこたえてくれる「インラインアセンブラで高速化して」
  • 全然わからなくても大丈夫「これらを踏まえて独自の最速アルゴリズム作って」→SIMD命令使った爆速実装爆誕

MODE1: 最大数(0xFFFFFFFFFFFFFFFF)を10000000回繰り返し評価

No. アルゴリズム名 実行時間(ms) 相対速度 特徴
1 Inline ASM (popcnt) - Ref 18.356 1.00x 参照用のpopcnt実装
2 Inline ASM (popcnt) 18.941 1.03x ハードウェアpopcnt命令を直接使用。高速で安定
3 Native 19.510 1.06x コンパイラ組み込み関数。最適な命令を自動選択
4 Inline ASM2 (SIMD) 20.607 1.12x SSSE3 SIMD命令による並列ビットカウント。メモリアライメント最適化
5 Parallel 46.177 2.52x 並列ビット演算による定数時間処理
6 Lookup16 59.847 3.26x 16ビット単位の参照。メモリ使用量128KB
7 Lookup8 109.045 5.94x 8ビット単位の参照。メモリ使用量256B
8 Shift 1378.498 75.10x 単純な右シフトによる実装。すべてのビットを走査
9 Kernighan 1658.306 90.34x 1の数に比例した処理回数。疎な入力に効果的

MODE2: 0から10000000未満までの値を順次評価

No. アルゴリズム名 実行時間(ms) 相対速度 特徴
1 Inline ASM (popcnt) - Ref 20.764 1.00x 参照用のpopcnt実装
2 Inline ASM (popcnt) 20.901 1.01x ハードウェアpopcnt命令を直接使用。高速で安定
3 Inline ASM2 (SIMD) 21.779 1.05x SSSE3 SIMD命令による並列ビットカウント。メモリアライメント最適化
4 Native 24.051 1.16x コンパイラ組み込み関数。最適な命令を自動選択
5 Lookup16 36.424 1.75x 16ビット単位の参照。メモリ使用量128KB
6 Lookup8 48.288 2.33x 8ビット単位の参照。メモリ使用量256B
7 Parallel 50.208 2.42x 並列ビット演算による定数時間処理
8 Kernighan 242.890 11.70x 1の数に比例した処理回数。疎な入力に効果的
9 Shift 369.153 17.78x 単純な右シフトによる実装。すべてのビットを走査

注記:RefとあるのはAIが勝手に実装したInline ASMの2度目の実行の性能です。1回目と比較して格段に差がある場合は、アルゴリズムでなく環境差異を考慮せよということでしょう。(すごっ)

実装

基本的にClineに指示を出しただけなのでここでは実際に入力したプロンプトと記載する。
途中Lookupテーブルが8bitだったので、16bitも作ってなどの細かい指示もしている。

実装1

最初は何も考えずにやってほしいことを指示しました。

プロンプト
Pythonで2進数の1の数をカウントを実装するコードを生成してください。
少なくとも5種のアルゴリズムを実装し、アルゴリズムごとに関数を作成、
実行結果
正確性テスト (各アルゴリズムの結果比較):
----------------------------------------------------------------------
      数値          2進数  Kernighan   String    Shift  Builtin   Lookup
----------------------------------------------------------------------
       0            0          0        0        0        0        0
       1            1          1        1        1        1        1
       2           10          1        1        1        1        1
       3           11          2        2        2        2        2
       4          100          1        1        1        1        1
       7          111          3        3        3        3        3
       8         1000          1        1        1        1        1
      15         1111          4        4        4        4        4
      16        10000          1        1        1        1        1
      31        11111          5        5        5        5        5
      32       100000          1        1        1        1        1
      63       111111          6        6        6        6        6
      64      1000000          1        1        1        1        1
     127      1111111          7        7        7        7        7
     128     10000000          1        1        1        1        1
     255     11111111          8        8        8        8        8

実装2

指示通りに作ってくれたものの、性能比較ができなかったため追加で繰り返し処理を行えるように指示します。

プロンプト
実行時間の比較ができるようにしてください。
各関数の引数には繰り返し回数をとるようにしてください。
実行結果
正確性テスト (各アルゴリズムの結果比較):
----------------------------------------------------------------------
      数値          2進数  Kernighan   String    Shift  Builtin   Lookup
----------------------------------------------------------------------
       0            0          0        0        0        0        0
       1            1          1        1        1        1        1
       2           10          1        1        1        1        1
       3           11          2        2        2        2        2
       4          100          1        1        1        1        1
       7          111          3        3        3        3        3
       8         1000          1        1        1        1        1
      15         1111          4        4        4        4        4
      16        10000          1        1        1        1        1
      31        11111          5        5        5        5        5
      32       100000          1        1        1        1        1
      63       111111          6        6        6        6        6
      64      1000000          1        1        1        1        1
     127      1111111          7        7        7        7        7
     128     10000000          1        1        1        1        1
     255     11111111          8        8        8        8        8

性能テスト (1,000,000回の繰り返し):
--------------------------------------------------
         アルゴリズム           実行時間 (ミリ秒)
--------------------------------------------------
         String         217.241 ms  (  1.00x)
        Builtin         255.124 ms  (  1.17x)
          Shift        1596.725 ms  (  7.35x)
      Kernighan        1736.173 ms  (  7.99x)
         Lookup       38306.467 ms  (176.33x)

実装3

出力や実装結果はほぼOKですが、各アルゴリズムを実装した関数を外部から何度もCallしてしまっています。
そういった性能を見たい場合もありますが、今回はアルゴリズムの差を見たいので修正指示を出します。
また、Lookupテーブル生成が関数内にあり、Lookupテーブルアルゴリズムが性能を出せないように見えたので修正指示を出します。
おそらく最速であろうNativeの関数も追加してもらいます。

プロンプト
ルックアップテーブル方式は毎回テーブルを作成しており不利です。事前処理で作成するようにしてください。
もう一つ、int.bit_countの計測も追加してください。
各アルゴリズムの関数内部でiterations回数を繰り返すように実装を見直してください。
実行結果
確性テスト (各アルゴリズムの結果比較):
--------------------------------------------------------------------------------
      数値          2進数  Kernighan   String    Shift  Builtin   Lookup   Native
--------------------------------------------------------------------------------
       0            0          0        0        0        0        0        0
       1            1          1        1        1        1        1        1
       2           10          1        1        1        1        1        1
       3           11          2        2        2        2        2        2
       4          100          1        1        1        1        1        1
       7          111          3        3        3        3        3        3
       8         1000          1        1        1        1        1        1
      15         1111          4        4        4        4        4        4
      16        10000          1        1        1        1        1        1
      31        11111          5        5        5        5        5        5
      32       100000          1        1        1        1        1        1
      63       111111          6        6        6        6        6        6
      64      1000000          1        1        1        1        1        1
     127      1111111          7        7        7        7        7        7
     128     10000000          1        1        1        1        1        1
     255     11111111          8        8        8        8        8        8

性能テスト (10,000回の繰り返し):
アルゴリズム: 実行時間
------------------------------
Native: 0.251ms (1.00x)
Lookup: 1.069ms (4.26x)
String: 1.494ms (5.95x)
Builtin: 1.898ms (7.56x)
Kernighan: 6.067ms (24.17x)
Shift: 6.799ms (27.08x)

実装4

ふと気が付くと本来試したかったビット操作のアルゴリズムが入っていません。
Qiitaから該当数る記事のコードを抜き出して追加するよう指示します。

以下のコードと同等のアルゴリズムを新たな関数として追加してください。

def popcount(x):

# 2bitごとの組に分け、立っているビット数を2bitで表現する
x = x - ((x >> 1) & 0x5555555555555555)

# 4bit整数に 上位2bit + 下位2bit を計算した値を入れる
x = (x & 0x3333333333333333) + ((x >> 2) & 0x3333333333333333)

x = (x + (x >> 4)) & 0x0f0f0f0f0f0f0f0f # 8bitごと
x = x + (x >> 8) # 16bitごと
x = x + (x >> 16) # 32bitごと
x = x + (x >> 32) # 64bitごと = 全部の合計
return x & 0x0000007f
実行結果
正確性テスト (各アルゴリズムの結果比較):
--------------------------------------------------------------------------------
      数値          2進数  Kernighan   String    Shift  Builtin  Lookup8 Lookup16   Native Parallel
--------------------------------------------------------------------------------
       0            0          0        0        0        0        0        0        0        0
       1            1          1        1        1        1        1        1        1        1
       2           10          1        1        1        1        1        1        1        1
       3           11          2        2        2        2        2        2        2        2
       4          100          1        1        1        1        1        1        1        1
       7          111          3        3        3        3        3        3        3        3
       8         1000          1        1        1        1        1        1        1        1
      15         1111          4        4        4        4        4        4        4        4
      16        10000          1        1        1        1        1        1        1        1
      31        11111          5        5        5        5        5        5        5        5
      32       100000          1        1        1        1        1        1        1        1
      63       111111          6        6        6        6        6        6        6        6
      64      1000000          1        1        1        1        1        1        1        1
     127      1111111          7        7        7        7        7        7        7        7
     128     10000000          1        1        1        1        1        1        1        1
     255     11111111          8        8        8        8        8        8        8        8

性能テスト (10,000回の繰り返し):
アルゴリズム: 実行時間
------------------------------
Native: 0.242ms (1.00x)
Lookup16: 0.778ms (3.21x)
Lookup8: 1.127ms (4.65x)
String: 1.441ms (5.95x)
Builtin: 1.792ms (7.40x)
Parallel: 4.117ms (17.00x)
Kernighan: 5.886ms (24.30x)
Shift: 6.541ms (27.00x)

評価用コード

最後に最終コードの全体を置いておきます。

Pythonバージョン全コード
import time

# 8ビット用ルックアップテーブル(0-255)
LOOKUP_TABLE_8 = [bin(i).count('1') for i in range(256)]

# 16ビット用ルックアップテーブル(0-65535)
LOOKUP_TABLE_16 = [bin(i).count('1') for i in range(65536)]

def count_ones_kernighan(n, iterations=1):
    """
    Brian Kernighanのアルゴリズムを使用して2進数の1の数をカウント
    n & (n-1)は最下位の1ビットを削除する
    """
    result = 0
    for _ in range(iterations):
        num = n
        count = 0
        while num:
            num &= (num-1)
            count += 1
        result = count
    return result

def count_ones_string(n, iterations=1):
    """
    文字列変換を使用して2進数の1の数をカウント
    """
    result = 0
    for _ in range(iterations):
        result = str(bin(n)).count('1')
    return result

def count_ones_shift(n, iterations=1):
    """
    ビットシフトを使用して2進数の1の数をカウント
    """
    result = 0
    for _ in range(iterations):
        count = 0
        num = n
        while num:
            count += num & 1
            num >>= 1
        result = count
    return result

def count_ones_builtin(n, iterations=1):
    """
    組み込み関数bin()を使用して2進数の1の数をカウント
    """
    result = 0
    for _ in range(iterations):
        result = bin(n)[2:].count('1')
    return result

def count_ones_lookup8(n, iterations=1):
    """
    8ビットルックアップテーブルを使用して2進数の1の数をカウント
    8ビットごとに処理
    """
    result = 0
    for _ in range(iterations):
        count = 0
        num = n
        while num:
            count += LOOKUP_TABLE_8[num & 0xFF]
            num >>= 8
        result = count
    return result

def count_ones_lookup16(n, iterations=1):
    """
    16ビットルックアップテーブルを使用して2進数の1の数をカウント
    16ビットごとに処理
    """
    result = 0
    for _ in range(iterations):
        count = 0
        num = n
        while num:
            count += LOOKUP_TABLE_16[num & 0xFFFF]
            num >>= 16
        result = count
    return result

def count_ones_parallel(n, iterations=1):
    """
    並列ビットカウントアルゴリズムを使用
    複数のビットを同時に処理することで高速化
    """
    result = 0
    for _ in range(iterations):
        x = n
        # 2bitごとの組に分け、立っているビット数を2bitで表現
        x = x - ((x >> 1) & 0x5555555555555555)
        # 4bit整数に 上位2bit + 下位2bit を計算した値を入れる
        x = (x & 0x3333333333333333) + ((x >> 2) & 0x3333333333333333)
        # 8bitごとに集計
        x = (x + (x >> 4)) & 0x0f0f0f0f0f0f0f0f
        # 16bitごとに集計
        x = x + (x >> 8)
        # 32bitごとに集計
        x = x + (x >> 16)
        # 64bitごとに集計(全体の合計)
        x = x + (x >> 32)
        result = x & 0x0000007f
    return result

def count_ones_native(n, iterations=1):
    """
    Pythonのint.bit_count()メソッドを使用
    """
    result = 0
    for _ in range(iterations):
        result = n.bit_count()
    return result

def benchmark_algorithms(test_number, iterations):
    """
    各アルゴリズムの実行時間を計測
    """
    algorithms = [
        ('Kernighan', count_ones_kernighan),
        ('String', count_ones_string),
        ('Shift', count_ones_shift),
        ('Builtin', count_ones_builtin),
        ('Lookup8', count_ones_lookup8),
        ('Lookup16', count_ones_lookup16),
        ('Native', count_ones_native),
        ('Parallel', count_ones_parallel)
    ]
    
    results = []
    for name, func in algorithms:
        start_time = time.time()
        func(test_number, iterations)
        end_time = time.time()
        execution_time = (end_time - start_time) * 1000
        results.append((name, execution_time))
    
    return results

def test_bit_counting(iterations=10000):  # 繰り返し回数を減らす
    test_numbers = [0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255]
    
    print(f"\n正確性テスト (各アルゴリズムの結果比較):")
    print("-" * 80)
    print(f"{'数値':>8} {'2進数':>12} {'Kernighan':>10} {'String':>8} {'Shift':>8} {'Builtin':>8} {'Lookup8':>8} {'Lookup16':>8} {'Native':>8} {'Parallel':>8}")
    print("-" * 80)
    
    for num in test_numbers:
        binary = bin(num)[2:]
        k_count = count_ones_kernighan(num, 1)
        s_count = count_ones_string(num, 1)
        sh_count = count_ones_shift(num, 1)
        b_count = count_ones_builtin(num, 1)
        l8_count = count_ones_lookup8(num, 1)
        l16_count = count_ones_lookup16(num, 1)
        n_count = count_ones_native(num, 1)
        p_count = count_ones_parallel(num, 1)
        print(f"{num:>8} {binary:>12} {k_count:>10} {s_count:>8} {sh_count:>8} {b_count:>8} {l8_count:>8} {l16_count:>8} {n_count:>8} {p_count:>8}")
    
    print(f"\n性能テスト ({iterations:,}回の繰り返し):")
    print("アルゴリズム: 実行時間")
    print("-" * 30)
    
    # テスト用の数値を16ビットに制限
    test_number = 0xFFFF  # 16ビットすべて1
    results = benchmark_algorithms(test_number, iterations)
    
    # 結果を実行時間でソート
    results.sort(key=lambda x: x[1])
    
    # 結果を表示
    fastest_time = results[0][1]
    for name, execution_time in results:
        print(f"{name}: {execution_time:.3f}ms ({execution_time/fastest_time:.2f}x)")

if __name__ == "__main__":
    test_bit_counting()

以下は、おまけのC言語バージョンです。
アルゴリズムの性能評価をする際には私は最適化オプションを付けずに実施します。理由はコンパイラが賢く、場合によってはアルゴリズムが変わるレベルで処理が変更される可能性があるからです。

コンパイルと実行
gcc bit_count.c -o bit_count  && ./bit_count
Cバージョン全コード
#include <stdio.h>
#include <time.h>
#include <stdint.h>
#include <string.h>
#include <immintrin.h>

// 8ビット用ルックアップテーブル(0-255)
unsigned char LOOKUP_TABLE_8[256];
// 16ビット用ルックアップテーブル(0-65535)
unsigned short LOOKUP_TABLE_16[65536];

// ルックアップテーブルの初期化
void init_lookup_tables() {
    for (int i = 0; i < 256; i++) {
        int count = 0;
        int n = i;
        while (n) {
            count += n & 1;
            n >>= 1;
        }
        LOOKUP_TABLE_8[i] = count;
    }

    for (int i = 0; i < 65536; i++) {
        int count = 0;
        int n = i;
        while (n) {
            count += n & 1;
            n >>= 1;
        }
        LOOKUP_TABLE_16[i] = count;
    }
}

int count_ones_kernighan(uint64_t n, int iterations, int mode) {
    int result = 0;
    if (mode == 1) {
        for (int i = 0; i < iterations; i++) {
            uint64_t num = n;
            int count = 0;
            while (num) {
                num &= (num - 1);
                count++;
            }
            result = count;
        }
    } else {
        for (uint64_t i = 0; i < iterations; i++) {
            uint64_t num = i;
            int count = 0;
            while (num) {
                num &= (num - 1);
                count++;
            }
            result = count;
        }
    }
    return result;
}

int count_ones_shift(uint64_t n, int iterations, int mode) {
    int result = 0;
    if (mode == 1) {
        for (int i = 0; i < iterations; i++) {
            int count = 0;
            uint64_t num = n;
            while (num) {
                count += num & 1;
                num >>= 1;
            }
            result = count;
        }
    } else {
        for (uint64_t i = 0; i < iterations; i++) {
            int count = 0;
            uint64_t num = i;
            while (num) {
                count += num & 1;
                num >>= 1;
            }
            result = count;
        }
    }
    return result;
}

int count_ones_lookup8(uint64_t n, int iterations, int mode) {
    int result = 0;
    if (mode == 1) {
        for (int i = 0; i < iterations; i++) {
            int count = 0;
            uint64_t num = n;
            while (num) {
                count += LOOKUP_TABLE_8[num & 0xFF];
                num >>= 8;
            }
            result = count;
        }
    } else {
        for (uint64_t i = 0; i < iterations; i++) {
            int count = 0;
            uint64_t num = i;
            while (num) {
                count += LOOKUP_TABLE_8[num & 0xFF];
                num >>= 8;
            }
            result = count;
        }
    }
    return result;
}

int count_ones_lookup16(uint64_t n, int iterations, int mode) {
    int result = 0;
    if (mode == 1) {
        for (int i = 0; i < iterations; i++) {
            int count = 0;
            uint64_t num = n;
            while (num) {
                count += LOOKUP_TABLE_16[num & 0xFFFF];
                num >>= 16;
            }
            result = count;
        }
    } else {
        for (uint64_t i = 0; i < iterations; i++) {
            int count = 0;
            uint64_t num = i;
            while (num) {
                count += LOOKUP_TABLE_16[num & 0xFFFF];
                num >>= 16;
            }
            result = count;
        }
    }
    return result;
}

int count_ones_native(uint64_t n, int iterations, int mode) {
    int result = 0;
    if (mode == 1) {
        for (int i = 0; i < iterations; i++) {
            result = __builtin_popcountll(n);
        }
    } else {
        for (uint64_t i = 0; i < iterations; i++) {
            result = __builtin_popcountll(i);
        }
    }
    return result;
}

int count_ones_parallel(uint64_t n, int iterations, int mode) {
    int result = 0;
    if (mode == 1) {
        for (int i = 0; i < iterations; i++) {
            uint64_t x = n;
            x = x - ((x >> 1) & 0x5555555555555555ULL);
            x = (x & 0x3333333333333333ULL) + ((x >> 2) & 0x3333333333333333ULL);
            x = (x + (x >> 4)) & 0x0f0f0f0f0f0f0f0fULL;
            x = x + (x >> 8);
            x = x + (x >> 16);
            x = x + (x >> 32);
            result = x & 0x7f;
        }
    } else {
        for (uint64_t i = 0; i < iterations; i++) {
            uint64_t x = i;
            x = x - ((x >> 1) & 0x5555555555555555ULL);
            x = (x & 0x3333333333333333ULL) + ((x >> 2) & 0x3333333333333333ULL);
            x = (x + (x >> 4)) & 0x0f0f0f0f0f0f0f0fULL;
            x = x + (x >> 8);
            x = x + (x >> 16);
            x = x + (x >> 32);
            result = x & 0x7f;
        }
    }
    return result;
}

// インラインアセンブラで実装した popcnt 命令を使用
int count_ones_inline_asm(uint64_t n, int iterations, int mode) {
    int result = 0;
    if (mode == 1) {
        for (int i = 0; i < iterations; i++) {
            uint64_t count;
            __asm__ ("popcnt %1, %0"
                     : "=r" (count)
                     : "r" (n)
                    );
            result = (int)count;
        }
    } else {
        for (uint64_t i = 0; i < iterations; i++) {
            uint64_t count;
            __asm__ ("popcnt %1, %0"
                     : "=r" (count)
                     : "r" (i)
                    );
            result = (int)count;
        }
    }
    return result;
}

// SSSE3 SIMD命令を使用した改良版ビットカウント
int count_ones_inline_asm2(uint64_t n, int iterations, int mode) {
    // 16バイトアライメントされたルックアップテーブル
    static const uint8_t lookup[16] __attribute__ ((aligned (16))) = {
        0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4
    };
    static const uint8_t mask[16] __attribute__ ((aligned (16))) = {
        0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,
        0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F
    };

    int result = 0;
    if (mode == 1) {
        for (int i = 0; i < iterations; i++) {
            uint64_t count;
            __asm__ volatile(
                // 値を下位64ビットにロード
                "movq %1, %%xmm0\n\t"
                // 上位64ビットをクリア
                "pxor %%xmm1, %%xmm1\n\t"
                // 8バイトを16バイトに展開
                "punpcklbw %%xmm1, %%xmm0\n\t"
                // 4ビットごとに分離
                "movdqa %%xmm0, %%xmm1\n\t"
                "psrlw $4, %%xmm1\n\t"
                "pand %2, %%xmm0\n\t"
                "pand %2, %%xmm1\n\t"
                // ルックアップテーブルを使用して1の数をカウント
                "movdqa %3, %%xmm2\n\t"
                "movdqa %%xmm2, %%xmm3\n\t"
                "pshufb %%xmm0, %%xmm2\n\t"
                "pshufb %%xmm1, %%xmm3\n\t"
                // 結果を合算
                "paddb %%xmm3, %%xmm2\n\t"
                // バイト単位で水平加算
                "pxor %%xmm0, %%xmm0\n\t"
                "psadbw %%xmm0, %%xmm2\n\t"
                // 上位64ビットと下位64ビットを合算
                "movhlps %%xmm2, %%xmm1\n\t"
                "paddd %%xmm1, %%xmm2\n\t"
                "movd %%xmm2, %0\n\t"
                : "=r" (count)
                : "r" (n),
                  "m" (mask),
                  "m" (lookup)
                : "xmm0", "xmm1", "xmm2", "xmm3"
            );
            result = (int)count;
        }
    } else {
        for (uint64_t i = 0; i < iterations; i++) {
            uint64_t count;
            __asm__ volatile(
                "movq %1, %%xmm0\n\t"
                "pxor %%xmm1, %%xmm1\n\t"
                "punpcklbw %%xmm1, %%xmm0\n\t"
                "movdqa %%xmm0, %%xmm1\n\t"
                "psrlw $4, %%xmm1\n\t"
                "pand %2, %%xmm0\n\t"
                "pand %2, %%xmm1\n\t"
                "movdqa %3, %%xmm2\n\t"
                "movdqa %%xmm2, %%xmm3\n\t"
                "pshufb %%xmm0, %%xmm2\n\t"
                "pshufb %%xmm1, %%xmm3\n\t"
                "paddb %%xmm3, %%xmm2\n\t"
                "pxor %%xmm0, %%xmm0\n\t"
                "psadbw %%xmm0, %%xmm2\n\t"
                "movhlps %%xmm2, %%xmm1\n\t"
                "paddd %%xmm1, %%xmm2\n\t"
                "movd %%xmm2, %0\n\t"
                : "=r" (count)
                : "r" (i),
                  "m" (mask),
                  "m" (lookup)
                : "xmm0", "xmm1", "xmm2", "xmm3"
            );
            result = (int)count;
        }
    }
    return result;
}

void print_algorithm_feature(const char* name) {
    if (strcmp(name, "Inline ASM2 (SIMD)") == 0) {
        printf("SSSE3 SIMD命令による並列ビットカウント。メモリアライメント最適化\n");
    } else if (strcmp(name, "Inline ASM (popcnt)") == 0) {
        printf("ハードウェアpopcnt命令を直接使用。高速で安定\n");
    } else if (strcmp(name, "Native") == 0) {
        printf("コンパイラ組み込み関数。最適な命令を自動選択\n");
    } else if (strcmp(name, "Kernighan") == 0) {
        printf("1の数に比例した処理回数。疎な入力に効果的\n");
    } else if (strcmp(name, "Shift") == 0) {
        printf("単純な右シフトによる実装。すべてのビットを走査\n");
    } else if (strcmp(name, "Lookup8") == 0) {
        printf("8ビット単位の参照。メモリ使用量256B\n");
    } else if (strcmp(name, "Lookup16") == 0) {
        printf("16ビット単位の参照。メモリ使用量128KB\n");
    } else if (strcmp(name, "Parallel") == 0) {
        printf("並列ビット演算による定数時間処理\n");
    } else if (strcmp(name, "Inline ASM (popcnt) - Ref") == 0) {
        printf("参照用のpopcnt実装\n");
    }
}

void benchmark_algorithms(uint64_t test_number, int iterations, int mode) {
    clock_t start, end;
    double times[9];  // 各アルゴリズムの実行時間を保存

    printf("\n性能テスト (%d回の繰り返し):\n", iterations);
    printf("============================================================================================================\n");
    printf("アルゴリズム名               実行時間(ms)    相対速度    特徴\n");
    printf("============================================================================================================\n");
    fflush(stdout);

    // Inline ASM2 (SIMD)
    start = clock();
    count_ones_inline_asm2(test_number, iterations, mode);
    end = clock();
    times[0] = ((double) (end - start)) / CLOCKS_PER_SEC * 1000;

    // Inline ASM
    start = clock();
    count_ones_inline_asm(test_number, iterations, mode);
    end = clock();
    times[1] = ((double) (end - start)) / CLOCKS_PER_SEC * 1000;

    // Native
    start = clock();
    count_ones_native(test_number, iterations, mode);
    end = clock();
    times[2] = ((double) (end - start)) / CLOCKS_PER_SEC * 1000;

    // Kernighan
    start = clock();
    count_ones_kernighan(test_number, iterations, mode);
    end = clock();
    times[3] = ((double) (end - start)) / CLOCKS_PER_SEC * 1000;

    // Shift
    start = clock();
    count_ones_shift(test_number, iterations, mode);
    end = clock();
    times[4] = ((double) (end - start)) / CLOCKS_PER_SEC * 1000;

    // Lookup8
    start = clock();
    count_ones_lookup8(test_number, iterations, mode);
    end = clock();
    times[5] = ((double) (end - start)) / CLOCKS_PER_SEC * 1000;

    // Lookup16
    start = clock();
    count_ones_lookup16(test_number, iterations, mode);
    end = clock();
    times[6] = ((double) (end - start)) / CLOCKS_PER_SEC * 1000;

    // Parallel
    start = clock();
    count_ones_parallel(test_number, iterations, mode);
    end = clock();
    times[7] = ((double) (end - start)) / CLOCKS_PER_SEC * 1000;

    // (以前の) Inline ASM (popcnt) を再度計測して比較
    start = clock();
    count_ones_inline_asm(test_number, iterations, mode);
    end = clock();
    times[8] = ((double) (end - start)) / CLOCKS_PER_SEC * 1000;

    // 結果を実行時間でソート
    typedef struct {
        const char* name;
        double time;
    } Result;

    Result results[] = {
        {"Inline ASM2 (SIMD)", times[0]},
        {"Inline ASM (popcnt)", times[1]},
        {"Native", times[2]},
        {"Kernighan", times[3]},
        {"Shift", times[4]},
        {"Lookup8", times[5]},
        {"Lookup16", times[6]},
        {"Parallel", times[7]},
        {"Inline ASM (popcnt) - Ref", times[8]} // 比較用にもう一度
    };

    // バブルソート(9要素をソート)
    for (int i = 0; i < 9; i++) {
        for (int j = 0; j < 9 - i - 1; j++) {
            if (results[j].time > results[j + 1].time) {
                Result temp = results[j];
                results[j] = results[j + 1];
                results[j + 1] = temp;
            }
        }
    }

    // 結果を表示(最速の時間を1とした相対時間で表示)
    double fastest_time = results[0].time;
    for (int i = 0; i < 9; i++) {
        printf("%-25s %12.3f %10.2fx    ", 
               results[i].name,
               results[i].time,
               results[i].time / fastest_time);
        print_algorithm_feature(results[i].name);
        fflush(stdout);
    }
    printf("============================================================================================================\n\n");
    fflush(stdout);
}

void test_bit_counting(int mode) {
    uint64_t test_numbers[] = {
        0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255,  // 基本的な小さい数値
        0xFF00FF00FF00FF00ULL,                                        // パターン化された値
        0xAAAAAAAAAAAAAAAAULL,                                        // 1010...1010パターン
        0x5555555555555555ULL,                                        // 0101...0101パターン
        0xFFFFFFFFFFFFFFFFULL,                                        // すべてのビットが1
        0x8000000000000000ULL,                                        // 最上位ビットのみ1
        0x0000000000000001ULL,                                        // 最下位ビットのみ1
        0xFFFFFFFFFFFFFFFEULL,                                        // 最下位ビット以外すべて1
        0x7FFFFFFFFFFFFFFFULL,                                        // 最上位ビット以外すべて1
        0x0F0F0F0F0F0F0F0FULL,                                        // 4ビットごとの交互パターン
        0xF0F0F0F0F0F0F0F0ULL                                         // 4ビットごとの交互パターン(逆)
    };
    int test_count = sizeof(test_numbers) / sizeof(test_numbers[0]);

    printf("\n正確性テスト (各アルゴリズムの結果比較):\n");
    printf("============================================================================================================\n");
    printf("アルゴリズムの特徴と評価:\n");
    printf("============================================================================================================\n");
    printf("Kernighan: ビットを1つずつ消去。最適化された実装で、1の数だけ繰り返し\n");
    printf("Shift: 単純な右シフトによる実装。すべてのビットを走査\n");
    printf("Lookup8: 8ビット単位でテーブル参照。メモリ使用量は256バイト\n");
    printf("Lookup16: 16ビット単位でテーブル参照。メモリ使用量は128KBだが高速\n");
    printf("Native: コンパイラ組み込み関数。ハードウェア命令を使用可能\n");
    printf("Parallel: ビット並列処理による高速化。定数時間で計算可能\n");
    printf("InlineASM: CPU命令(popcnt)を直接使用。最も効率的\n");
    printf("InlineASM2: 4ビット単位のルックアップによる実装\n");
    printf("============================================================================================================\n\n");

    printf("%20s %20s %8s %8s %8s %8s %8s %8s %8s %8s\n",
           "数値", "16進数", "Kernig", "Shift", "Look8", "Look16", "Native", "Paral", "ASM", "ASM2");
    printf("============================================================================================================\n");

    char hex_str[17];
    char dec_str[21];
    for (int i = 0; i < test_count; i++) {
        uint64_t num = test_numbers[i];
        sprintf(hex_str, "%016lX", num);
        sprintf(dec_str, "%llu", (unsigned long long)num);

        int k_count = count_ones_kernighan(num, 1, mode);
        int sh_count = count_ones_shift(num, 1, mode);
        int l8_count = count_ones_lookup8(num, 1, mode);
        int l16_count = count_ones_lookup16(num, 1, mode);
        int n_count = count_ones_native(num, 1, mode);
        int p_count = count_ones_parallel(num, 1, mode);
        int asm_count = count_ones_inline_asm(num, 1, mode);
        int asm2_count = count_ones_inline_asm2(num, 1, mode);

        printf("%20s %20s %8d %8d %8d %8d %8d %8d %8d %8d\n",
               dec_str, hex_str, k_count, sh_count, l8_count, l16_count, n_count, p_count, asm_count, asm2_count);
    }

    // MODE1のベンチマークテスト
    printf("\nMODE1: 最大数(0xFFFFFFFFFFFFFFFF)を%d回繰り返し評価\n", 10000000);
    fflush(stdout);
    benchmark_algorithms(0xFFFFFFFFFFFFFFFFULL, 10000000, 1);
    
    printf("\nMODE2: 0から%d未満までの値を順次評価\n", 10000000);
    fflush(stdout);
    benchmark_algorithms(0, 10000000, 2);
    fflush(stdout);
}

int main(int argc, char *argv[]) {
    setvbuf(stdout, NULL, _IONBF, 0);  // 出力バッファリングを無効化
    init_lookup_tables();
    test_bit_counting(1);  // 正確性テストはMODE1として実行
    return 0;
}

3
4
3

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?