背景(ポエム)
この記事は、あるエンジニアが「数を2進数表記したときの1の数え方にはたくさんの方法があって、とても素晴らしいアルゴリズムがある」という話を聞いたことから始まる。
紹介した人は当然、言語によっても性能が異なるということを知っていたうえで、アルゴリズムの重要性や面白さの説明のために情報を紹介してくれた。
どんなにハードウェアが進化してもアルゴリズムというのは超重要で、それは処理速度がオーダレベルで変わるからだ。
しかし、同時に昔最速だった実装が今日でも早いかどうかは不明だ。
となると、環境ごとに試す必要もあるし、実装して試すのが速いがめんどくさい。
めんどくさいが、今や我々にはAIがあるんだからグダグダ言わずにAIにやらせればいいじゃないか。
結論First
想像した順ではなかった。
2025年1月時点では、AIが生成するコードは、計測方法や、LOOKUPテーブルの作り方など「ん?」と感じる部分があるが、多くの場面では許容出来ると思う。
- VSCode & Cline & Claude Sonnetでほぼ自動。エンジニアあと何年仕事あるかな。
- 専用の関数はやはり最速
- 高級言語は作らなければ作らないほど早い(組み込み関数をうまく使え)
Pythonで8つのアルゴリズムの性能比較結果(10,000回の繰り返し)
No. | アルゴリズム | 処理概要 | 処理時間 (ms) | 最速アルゴリズムとの比率 |
---|---|---|---|---|
1 | Native | Pythonのint.bit_count() メソッド |
0.242 | 1.00x |
2 | Lookup16 | 16ビットごとのルックアップテーブル使用 | 0.778 | 3.21x |
3 | Lookup8 | 8ビットごとのルックアップテーブル使用 | 1.127 | 4.65x |
4 | String | ビットシフト演算を用いて、各ビットが1を確認 | 1.441 | 5.95x |
5 | Builtin | Pythonの組み込み関数bin() を使用して2進数文字列に変換し、'1'をcount |
1.792 | 7.40x |
6 | Parallel | 並列ビットカウントアルゴリズム | 4.117 | 17.00x |
7 | Kernighan | Brian Kernighanのアルゴリズム | 5.886 | 24.30x |
8 | Shift | ビットシフト演算を用いて、各ビットが1であるかどうかを順に確認 | 6.541 | 27.00x |
結論Second(おまけです)
一周まわって、じゃぁ、C言語では?という興味がわくので、それもAIに「Pythonのコードと同じ内容をC言語で実装して」と依頼。
出来上がったのが以下。
まとめると
- コンパイル方法、実行方法まで教えてくれる超親切
- ちょっと知っている用語で無茶振りをしたらこたえてくれる「インラインアセンブラで高速化して」
- 全然わからなくても大丈夫「これらを踏まえて独自の最速アルゴリズム作って」→SIMD命令使った爆速実装爆誕
MODE1: 最大数(0xFFFFFFFFFFFFFFFF)を10000000回繰り返し評価
No. | アルゴリズム名 | 実行時間(ms) | 相対速度 | 特徴 |
---|---|---|---|---|
1 | Inline ASM (popcnt) - Ref | 18.356 | 1.00x | 参照用のpopcnt実装 |
2 | Inline ASM (popcnt) | 18.941 | 1.03x | ハードウェアpopcnt命令を直接使用。高速で安定 |
3 | Native | 19.510 | 1.06x | コンパイラ組み込み関数。最適な命令を自動選択 |
4 | Inline ASM2 (SIMD) | 20.607 | 1.12x | SSSE3 SIMD命令による並列ビットカウント。メモリアライメント最適化 |
5 | Parallel | 46.177 | 2.52x | 並列ビット演算による定数時間処理 |
6 | Lookup16 | 59.847 | 3.26x | 16ビット単位の参照。メモリ使用量128KB |
7 | Lookup8 | 109.045 | 5.94x | 8ビット単位の参照。メモリ使用量256B |
8 | Shift | 1378.498 | 75.10x | 単純な右シフトによる実装。すべてのビットを走査 |
9 | Kernighan | 1658.306 | 90.34x | 1の数に比例した処理回数。疎な入力に効果的 |
MODE2: 0から10000000未満までの値を順次評価
No. | アルゴリズム名 | 実行時間(ms) | 相対速度 | 特徴 |
---|---|---|---|---|
1 | Inline ASM (popcnt) - Ref | 20.764 | 1.00x | 参照用のpopcnt実装 |
2 | Inline ASM (popcnt) | 20.901 | 1.01x | ハードウェアpopcnt命令を直接使用。高速で安定 |
3 | Inline ASM2 (SIMD) | 21.779 | 1.05x | SSSE3 SIMD命令による並列ビットカウント。メモリアライメント最適化 |
4 | Native | 24.051 | 1.16x | コンパイラ組み込み関数。最適な命令を自動選択 |
5 | Lookup16 | 36.424 | 1.75x | 16ビット単位の参照。メモリ使用量128KB |
6 | Lookup8 | 48.288 | 2.33x | 8ビット単位の参照。メモリ使用量256B |
7 | Parallel | 50.208 | 2.42x | 並列ビット演算による定数時間処理 |
8 | Kernighan | 242.890 | 11.70x | 1の数に比例した処理回数。疎な入力に効果的 |
9 | Shift | 369.153 | 17.78x | 単純な右シフトによる実装。すべてのビットを走査 |
注記:RefとあるのはAIが勝手に実装したInline ASMの2度目の実行の性能です。1回目と比較して格段に差がある場合は、アルゴリズムでなく環境差異を考慮せよということでしょう。(すごっ)
実装
基本的にClineに指示を出しただけなのでここでは実際に入力したプロンプトと記載する。
途中Lookupテーブルが8bitだったので、16bitも作ってなどの細かい指示もしている。
実装1
最初は何も考えずにやってほしいことを指示しました。
Pythonで2進数の1の数をカウントを実装するコードを生成してください。
少なくとも5種のアルゴリズムを実装し、アルゴリズムごとに関数を作成、
正確性テスト (各アルゴリズムの結果比較):
----------------------------------------------------------------------
数値 2進数 Kernighan String Shift Builtin Lookup
----------------------------------------------------------------------
0 0 0 0 0 0 0
1 1 1 1 1 1 1
2 10 1 1 1 1 1
3 11 2 2 2 2 2
4 100 1 1 1 1 1
7 111 3 3 3 3 3
8 1000 1 1 1 1 1
15 1111 4 4 4 4 4
16 10000 1 1 1 1 1
31 11111 5 5 5 5 5
32 100000 1 1 1 1 1
63 111111 6 6 6 6 6
64 1000000 1 1 1 1 1
127 1111111 7 7 7 7 7
128 10000000 1 1 1 1 1
255 11111111 8 8 8 8 8
実装2
指示通りに作ってくれたものの、性能比較ができなかったため追加で繰り返し処理を行えるように指示します。
実行時間の比較ができるようにしてください。
各関数の引数には繰り返し回数をとるようにしてください。
正確性テスト (各アルゴリズムの結果比較):
----------------------------------------------------------------------
数値 2進数 Kernighan String Shift Builtin Lookup
----------------------------------------------------------------------
0 0 0 0 0 0 0
1 1 1 1 1 1 1
2 10 1 1 1 1 1
3 11 2 2 2 2 2
4 100 1 1 1 1 1
7 111 3 3 3 3 3
8 1000 1 1 1 1 1
15 1111 4 4 4 4 4
16 10000 1 1 1 1 1
31 11111 5 5 5 5 5
32 100000 1 1 1 1 1
63 111111 6 6 6 6 6
64 1000000 1 1 1 1 1
127 1111111 7 7 7 7 7
128 10000000 1 1 1 1 1
255 11111111 8 8 8 8 8
性能テスト (1,000,000回の繰り返し):
--------------------------------------------------
アルゴリズム 実行時間 (ミリ秒)
--------------------------------------------------
String 217.241 ms ( 1.00x)
Builtin 255.124 ms ( 1.17x)
Shift 1596.725 ms ( 7.35x)
Kernighan 1736.173 ms ( 7.99x)
Lookup 38306.467 ms (176.33x)
実装3
出力や実装結果はほぼOKですが、各アルゴリズムを実装した関数を外部から何度もCallしてしまっています。
そういった性能を見たい場合もありますが、今回はアルゴリズムの差を見たいので修正指示を出します。
また、Lookupテーブル生成が関数内にあり、Lookupテーブルアルゴリズムが性能を出せないように見えたので修正指示を出します。
おそらく最速であろうNativeの関数も追加してもらいます。
ルックアップテーブル方式は毎回テーブルを作成しており不利です。事前処理で作成するようにしてください。
もう一つ、int.bit_countの計測も追加してください。
各アルゴリズムの関数内部でiterations回数を繰り返すように実装を見直してください。
確性テスト (各アルゴリズムの結果比較):
--------------------------------------------------------------------------------
数値 2進数 Kernighan String Shift Builtin Lookup Native
--------------------------------------------------------------------------------
0 0 0 0 0 0 0 0
1 1 1 1 1 1 1 1
2 10 1 1 1 1 1 1
3 11 2 2 2 2 2 2
4 100 1 1 1 1 1 1
7 111 3 3 3 3 3 3
8 1000 1 1 1 1 1 1
15 1111 4 4 4 4 4 4
16 10000 1 1 1 1 1 1
31 11111 5 5 5 5 5 5
32 100000 1 1 1 1 1 1
63 111111 6 6 6 6 6 6
64 1000000 1 1 1 1 1 1
127 1111111 7 7 7 7 7 7
128 10000000 1 1 1 1 1 1
255 11111111 8 8 8 8 8 8
性能テスト (10,000回の繰り返し):
アルゴリズム: 実行時間
------------------------------
Native: 0.251ms (1.00x)
Lookup: 1.069ms (4.26x)
String: 1.494ms (5.95x)
Builtin: 1.898ms (7.56x)
Kernighan: 6.067ms (24.17x)
Shift: 6.799ms (27.08x)
実装4
ふと気が付くと本来試したかったビット操作のアルゴリズムが入っていません。
Qiitaから該当数る記事のコードを抜き出して追加するよう指示します。
以下のコードと同等のアルゴリズムを新たな関数として追加してください。
def popcount(x):
# 2bitごとの組に分け、立っているビット数を2bitで表現する
x = x - ((x >> 1) & 0x5555555555555555)
# 4bit整数に 上位2bit + 下位2bit を計算した値を入れる
x = (x & 0x3333333333333333) + ((x >> 2) & 0x3333333333333333)
x = (x + (x >> 4)) & 0x0f0f0f0f0f0f0f0f # 8bitごと
x = x + (x >> 8) # 16bitごと
x = x + (x >> 16) # 32bitごと
x = x + (x >> 32) # 64bitごと = 全部の合計
return x & 0x0000007f
正確性テスト (各アルゴリズムの結果比較):
--------------------------------------------------------------------------------
数値 2進数 Kernighan String Shift Builtin Lookup8 Lookup16 Native Parallel
--------------------------------------------------------------------------------
0 0 0 0 0 0 0 0 0 0
1 1 1 1 1 1 1 1 1 1
2 10 1 1 1 1 1 1 1 1
3 11 2 2 2 2 2 2 2 2
4 100 1 1 1 1 1 1 1 1
7 111 3 3 3 3 3 3 3 3
8 1000 1 1 1 1 1 1 1 1
15 1111 4 4 4 4 4 4 4 4
16 10000 1 1 1 1 1 1 1 1
31 11111 5 5 5 5 5 5 5 5
32 100000 1 1 1 1 1 1 1 1
63 111111 6 6 6 6 6 6 6 6
64 1000000 1 1 1 1 1 1 1 1
127 1111111 7 7 7 7 7 7 7 7
128 10000000 1 1 1 1 1 1 1 1
255 11111111 8 8 8 8 8 8 8 8
性能テスト (10,000回の繰り返し):
アルゴリズム: 実行時間
------------------------------
Native: 0.242ms (1.00x)
Lookup16: 0.778ms (3.21x)
Lookup8: 1.127ms (4.65x)
String: 1.441ms (5.95x)
Builtin: 1.792ms (7.40x)
Parallel: 4.117ms (17.00x)
Kernighan: 5.886ms (24.30x)
Shift: 6.541ms (27.00x)
評価用コード
最後に最終コードの全体を置いておきます。
import time
# 8ビット用ルックアップテーブル(0-255)
LOOKUP_TABLE_8 = [bin(i).count('1') for i in range(256)]
# 16ビット用ルックアップテーブル(0-65535)
LOOKUP_TABLE_16 = [bin(i).count('1') for i in range(65536)]
def count_ones_kernighan(n, iterations=1):
"""
Brian Kernighanのアルゴリズムを使用して2進数の1の数をカウント
n & (n-1)は最下位の1ビットを削除する
"""
result = 0
for _ in range(iterations):
num = n
count = 0
while num:
num &= (num-1)
count += 1
result = count
return result
def count_ones_string(n, iterations=1):
"""
文字列変換を使用して2進数の1の数をカウント
"""
result = 0
for _ in range(iterations):
result = str(bin(n)).count('1')
return result
def count_ones_shift(n, iterations=1):
"""
ビットシフトを使用して2進数の1の数をカウント
"""
result = 0
for _ in range(iterations):
count = 0
num = n
while num:
count += num & 1
num >>= 1
result = count
return result
def count_ones_builtin(n, iterations=1):
"""
組み込み関数bin()を使用して2進数の1の数をカウント
"""
result = 0
for _ in range(iterations):
result = bin(n)[2:].count('1')
return result
def count_ones_lookup8(n, iterations=1):
"""
8ビットルックアップテーブルを使用して2進数の1の数をカウント
8ビットごとに処理
"""
result = 0
for _ in range(iterations):
count = 0
num = n
while num:
count += LOOKUP_TABLE_8[num & 0xFF]
num >>= 8
result = count
return result
def count_ones_lookup16(n, iterations=1):
"""
16ビットルックアップテーブルを使用して2進数の1の数をカウント
16ビットごとに処理
"""
result = 0
for _ in range(iterations):
count = 0
num = n
while num:
count += LOOKUP_TABLE_16[num & 0xFFFF]
num >>= 16
result = count
return result
def count_ones_parallel(n, iterations=1):
"""
並列ビットカウントアルゴリズムを使用
複数のビットを同時に処理することで高速化
"""
result = 0
for _ in range(iterations):
x = n
# 2bitごとの組に分け、立っているビット数を2bitで表現
x = x - ((x >> 1) & 0x5555555555555555)
# 4bit整数に 上位2bit + 下位2bit を計算した値を入れる
x = (x & 0x3333333333333333) + ((x >> 2) & 0x3333333333333333)
# 8bitごとに集計
x = (x + (x >> 4)) & 0x0f0f0f0f0f0f0f0f
# 16bitごとに集計
x = x + (x >> 8)
# 32bitごとに集計
x = x + (x >> 16)
# 64bitごとに集計(全体の合計)
x = x + (x >> 32)
result = x & 0x0000007f
return result
def count_ones_native(n, iterations=1):
"""
Pythonのint.bit_count()メソッドを使用
"""
result = 0
for _ in range(iterations):
result = n.bit_count()
return result
def benchmark_algorithms(test_number, iterations):
"""
各アルゴリズムの実行時間を計測
"""
algorithms = [
('Kernighan', count_ones_kernighan),
('String', count_ones_string),
('Shift', count_ones_shift),
('Builtin', count_ones_builtin),
('Lookup8', count_ones_lookup8),
('Lookup16', count_ones_lookup16),
('Native', count_ones_native),
('Parallel', count_ones_parallel)
]
results = []
for name, func in algorithms:
start_time = time.time()
func(test_number, iterations)
end_time = time.time()
execution_time = (end_time - start_time) * 1000
results.append((name, execution_time))
return results
def test_bit_counting(iterations=10000): # 繰り返し回数を減らす
test_numbers = [0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255]
print(f"\n正確性テスト (各アルゴリズムの結果比較):")
print("-" * 80)
print(f"{'数値':>8} {'2進数':>12} {'Kernighan':>10} {'String':>8} {'Shift':>8} {'Builtin':>8} {'Lookup8':>8} {'Lookup16':>8} {'Native':>8} {'Parallel':>8}")
print("-" * 80)
for num in test_numbers:
binary = bin(num)[2:]
k_count = count_ones_kernighan(num, 1)
s_count = count_ones_string(num, 1)
sh_count = count_ones_shift(num, 1)
b_count = count_ones_builtin(num, 1)
l8_count = count_ones_lookup8(num, 1)
l16_count = count_ones_lookup16(num, 1)
n_count = count_ones_native(num, 1)
p_count = count_ones_parallel(num, 1)
print(f"{num:>8} {binary:>12} {k_count:>10} {s_count:>8} {sh_count:>8} {b_count:>8} {l8_count:>8} {l16_count:>8} {n_count:>8} {p_count:>8}")
print(f"\n性能テスト ({iterations:,}回の繰り返し):")
print("アルゴリズム: 実行時間")
print("-" * 30)
# テスト用の数値を16ビットに制限
test_number = 0xFFFF # 16ビットすべて1
results = benchmark_algorithms(test_number, iterations)
# 結果を実行時間でソート
results.sort(key=lambda x: x[1])
# 結果を表示
fastest_time = results[0][1]
for name, execution_time in results:
print(f"{name}: {execution_time:.3f}ms ({execution_time/fastest_time:.2f}x)")
if __name__ == "__main__":
test_bit_counting()
以下は、おまけのC言語バージョンです。
アルゴリズムの性能評価をする際には私は最適化オプションを付けずに実施します。理由はコンパイラが賢く、場合によってはアルゴリズムが変わるレベルで処理が変更される可能性があるからです。
gcc bit_count.c -o bit_count && ./bit_count
#include <stdio.h>
#include <time.h>
#include <stdint.h>
#include <string.h>
#include <immintrin.h>
// 8ビット用ルックアップテーブル(0-255)
unsigned char LOOKUP_TABLE_8[256];
// 16ビット用ルックアップテーブル(0-65535)
unsigned short LOOKUP_TABLE_16[65536];
// ルックアップテーブルの初期化
void init_lookup_tables() {
for (int i = 0; i < 256; i++) {
int count = 0;
int n = i;
while (n) {
count += n & 1;
n >>= 1;
}
LOOKUP_TABLE_8[i] = count;
}
for (int i = 0; i < 65536; i++) {
int count = 0;
int n = i;
while (n) {
count += n & 1;
n >>= 1;
}
LOOKUP_TABLE_16[i] = count;
}
}
int count_ones_kernighan(uint64_t n, int iterations, int mode) {
int result = 0;
if (mode == 1) {
for (int i = 0; i < iterations; i++) {
uint64_t num = n;
int count = 0;
while (num) {
num &= (num - 1);
count++;
}
result = count;
}
} else {
for (uint64_t i = 0; i < iterations; i++) {
uint64_t num = i;
int count = 0;
while (num) {
num &= (num - 1);
count++;
}
result = count;
}
}
return result;
}
int count_ones_shift(uint64_t n, int iterations, int mode) {
int result = 0;
if (mode == 1) {
for (int i = 0; i < iterations; i++) {
int count = 0;
uint64_t num = n;
while (num) {
count += num & 1;
num >>= 1;
}
result = count;
}
} else {
for (uint64_t i = 0; i < iterations; i++) {
int count = 0;
uint64_t num = i;
while (num) {
count += num & 1;
num >>= 1;
}
result = count;
}
}
return result;
}
int count_ones_lookup8(uint64_t n, int iterations, int mode) {
int result = 0;
if (mode == 1) {
for (int i = 0; i < iterations; i++) {
int count = 0;
uint64_t num = n;
while (num) {
count += LOOKUP_TABLE_8[num & 0xFF];
num >>= 8;
}
result = count;
}
} else {
for (uint64_t i = 0; i < iterations; i++) {
int count = 0;
uint64_t num = i;
while (num) {
count += LOOKUP_TABLE_8[num & 0xFF];
num >>= 8;
}
result = count;
}
}
return result;
}
int count_ones_lookup16(uint64_t n, int iterations, int mode) {
int result = 0;
if (mode == 1) {
for (int i = 0; i < iterations; i++) {
int count = 0;
uint64_t num = n;
while (num) {
count += LOOKUP_TABLE_16[num & 0xFFFF];
num >>= 16;
}
result = count;
}
} else {
for (uint64_t i = 0; i < iterations; i++) {
int count = 0;
uint64_t num = i;
while (num) {
count += LOOKUP_TABLE_16[num & 0xFFFF];
num >>= 16;
}
result = count;
}
}
return result;
}
int count_ones_native(uint64_t n, int iterations, int mode) {
int result = 0;
if (mode == 1) {
for (int i = 0; i < iterations; i++) {
result = __builtin_popcountll(n);
}
} else {
for (uint64_t i = 0; i < iterations; i++) {
result = __builtin_popcountll(i);
}
}
return result;
}
int count_ones_parallel(uint64_t n, int iterations, int mode) {
int result = 0;
if (mode == 1) {
for (int i = 0; i < iterations; i++) {
uint64_t x = n;
x = x - ((x >> 1) & 0x5555555555555555ULL);
x = (x & 0x3333333333333333ULL) + ((x >> 2) & 0x3333333333333333ULL);
x = (x + (x >> 4)) & 0x0f0f0f0f0f0f0f0fULL;
x = x + (x >> 8);
x = x + (x >> 16);
x = x + (x >> 32);
result = x & 0x7f;
}
} else {
for (uint64_t i = 0; i < iterations; i++) {
uint64_t x = i;
x = x - ((x >> 1) & 0x5555555555555555ULL);
x = (x & 0x3333333333333333ULL) + ((x >> 2) & 0x3333333333333333ULL);
x = (x + (x >> 4)) & 0x0f0f0f0f0f0f0f0fULL;
x = x + (x >> 8);
x = x + (x >> 16);
x = x + (x >> 32);
result = x & 0x7f;
}
}
return result;
}
// インラインアセンブラで実装した popcnt 命令を使用
int count_ones_inline_asm(uint64_t n, int iterations, int mode) {
int result = 0;
if (mode == 1) {
for (int i = 0; i < iterations; i++) {
uint64_t count;
__asm__ ("popcnt %1, %0"
: "=r" (count)
: "r" (n)
);
result = (int)count;
}
} else {
for (uint64_t i = 0; i < iterations; i++) {
uint64_t count;
__asm__ ("popcnt %1, %0"
: "=r" (count)
: "r" (i)
);
result = (int)count;
}
}
return result;
}
// SSSE3 SIMD命令を使用した改良版ビットカウント
int count_ones_inline_asm2(uint64_t n, int iterations, int mode) {
// 16バイトアライメントされたルックアップテーブル
static const uint8_t lookup[16] __attribute__ ((aligned (16))) = {
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4
};
static const uint8_t mask[16] __attribute__ ((aligned (16))) = {
0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,
0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F
};
int result = 0;
if (mode == 1) {
for (int i = 0; i < iterations; i++) {
uint64_t count;
__asm__ volatile(
// 値を下位64ビットにロード
"movq %1, %%xmm0\n\t"
// 上位64ビットをクリア
"pxor %%xmm1, %%xmm1\n\t"
// 8バイトを16バイトに展開
"punpcklbw %%xmm1, %%xmm0\n\t"
// 4ビットごとに分離
"movdqa %%xmm0, %%xmm1\n\t"
"psrlw $4, %%xmm1\n\t"
"pand %2, %%xmm0\n\t"
"pand %2, %%xmm1\n\t"
// ルックアップテーブルを使用して1の数をカウント
"movdqa %3, %%xmm2\n\t"
"movdqa %%xmm2, %%xmm3\n\t"
"pshufb %%xmm0, %%xmm2\n\t"
"pshufb %%xmm1, %%xmm3\n\t"
// 結果を合算
"paddb %%xmm3, %%xmm2\n\t"
// バイト単位で水平加算
"pxor %%xmm0, %%xmm0\n\t"
"psadbw %%xmm0, %%xmm2\n\t"
// 上位64ビットと下位64ビットを合算
"movhlps %%xmm2, %%xmm1\n\t"
"paddd %%xmm1, %%xmm2\n\t"
"movd %%xmm2, %0\n\t"
: "=r" (count)
: "r" (n),
"m" (mask),
"m" (lookup)
: "xmm0", "xmm1", "xmm2", "xmm3"
);
result = (int)count;
}
} else {
for (uint64_t i = 0; i < iterations; i++) {
uint64_t count;
__asm__ volatile(
"movq %1, %%xmm0\n\t"
"pxor %%xmm1, %%xmm1\n\t"
"punpcklbw %%xmm1, %%xmm0\n\t"
"movdqa %%xmm0, %%xmm1\n\t"
"psrlw $4, %%xmm1\n\t"
"pand %2, %%xmm0\n\t"
"pand %2, %%xmm1\n\t"
"movdqa %3, %%xmm2\n\t"
"movdqa %%xmm2, %%xmm3\n\t"
"pshufb %%xmm0, %%xmm2\n\t"
"pshufb %%xmm1, %%xmm3\n\t"
"paddb %%xmm3, %%xmm2\n\t"
"pxor %%xmm0, %%xmm0\n\t"
"psadbw %%xmm0, %%xmm2\n\t"
"movhlps %%xmm2, %%xmm1\n\t"
"paddd %%xmm1, %%xmm2\n\t"
"movd %%xmm2, %0\n\t"
: "=r" (count)
: "r" (i),
"m" (mask),
"m" (lookup)
: "xmm0", "xmm1", "xmm2", "xmm3"
);
result = (int)count;
}
}
return result;
}
void print_algorithm_feature(const char* name) {
if (strcmp(name, "Inline ASM2 (SIMD)") == 0) {
printf("SSSE3 SIMD命令による並列ビットカウント。メモリアライメント最適化\n");
} else if (strcmp(name, "Inline ASM (popcnt)") == 0) {
printf("ハードウェアpopcnt命令を直接使用。高速で安定\n");
} else if (strcmp(name, "Native") == 0) {
printf("コンパイラ組み込み関数。最適な命令を自動選択\n");
} else if (strcmp(name, "Kernighan") == 0) {
printf("1の数に比例した処理回数。疎な入力に効果的\n");
} else if (strcmp(name, "Shift") == 0) {
printf("単純な右シフトによる実装。すべてのビットを走査\n");
} else if (strcmp(name, "Lookup8") == 0) {
printf("8ビット単位の参照。メモリ使用量256B\n");
} else if (strcmp(name, "Lookup16") == 0) {
printf("16ビット単位の参照。メモリ使用量128KB\n");
} else if (strcmp(name, "Parallel") == 0) {
printf("並列ビット演算による定数時間処理\n");
} else if (strcmp(name, "Inline ASM (popcnt) - Ref") == 0) {
printf("参照用のpopcnt実装\n");
}
}
void benchmark_algorithms(uint64_t test_number, int iterations, int mode) {
clock_t start, end;
double times[9]; // 各アルゴリズムの実行時間を保存
printf("\n性能テスト (%d回の繰り返し):\n", iterations);
printf("============================================================================================================\n");
printf("アルゴリズム名 実行時間(ms) 相対速度 特徴\n");
printf("============================================================================================================\n");
fflush(stdout);
// Inline ASM2 (SIMD)
start = clock();
count_ones_inline_asm2(test_number, iterations, mode);
end = clock();
times[0] = ((double) (end - start)) / CLOCKS_PER_SEC * 1000;
// Inline ASM
start = clock();
count_ones_inline_asm(test_number, iterations, mode);
end = clock();
times[1] = ((double) (end - start)) / CLOCKS_PER_SEC * 1000;
// Native
start = clock();
count_ones_native(test_number, iterations, mode);
end = clock();
times[2] = ((double) (end - start)) / CLOCKS_PER_SEC * 1000;
// Kernighan
start = clock();
count_ones_kernighan(test_number, iterations, mode);
end = clock();
times[3] = ((double) (end - start)) / CLOCKS_PER_SEC * 1000;
// Shift
start = clock();
count_ones_shift(test_number, iterations, mode);
end = clock();
times[4] = ((double) (end - start)) / CLOCKS_PER_SEC * 1000;
// Lookup8
start = clock();
count_ones_lookup8(test_number, iterations, mode);
end = clock();
times[5] = ((double) (end - start)) / CLOCKS_PER_SEC * 1000;
// Lookup16
start = clock();
count_ones_lookup16(test_number, iterations, mode);
end = clock();
times[6] = ((double) (end - start)) / CLOCKS_PER_SEC * 1000;
// Parallel
start = clock();
count_ones_parallel(test_number, iterations, mode);
end = clock();
times[7] = ((double) (end - start)) / CLOCKS_PER_SEC * 1000;
// (以前の) Inline ASM (popcnt) を再度計測して比較
start = clock();
count_ones_inline_asm(test_number, iterations, mode);
end = clock();
times[8] = ((double) (end - start)) / CLOCKS_PER_SEC * 1000;
// 結果を実行時間でソート
typedef struct {
const char* name;
double time;
} Result;
Result results[] = {
{"Inline ASM2 (SIMD)", times[0]},
{"Inline ASM (popcnt)", times[1]},
{"Native", times[2]},
{"Kernighan", times[3]},
{"Shift", times[4]},
{"Lookup8", times[5]},
{"Lookup16", times[6]},
{"Parallel", times[7]},
{"Inline ASM (popcnt) - Ref", times[8]} // 比較用にもう一度
};
// バブルソート(9要素をソート)
for (int i = 0; i < 9; i++) {
for (int j = 0; j < 9 - i - 1; j++) {
if (results[j].time > results[j + 1].time) {
Result temp = results[j];
results[j] = results[j + 1];
results[j + 1] = temp;
}
}
}
// 結果を表示(最速の時間を1とした相対時間で表示)
double fastest_time = results[0].time;
for (int i = 0; i < 9; i++) {
printf("%-25s %12.3f %10.2fx ",
results[i].name,
results[i].time,
results[i].time / fastest_time);
print_algorithm_feature(results[i].name);
fflush(stdout);
}
printf("============================================================================================================\n\n");
fflush(stdout);
}
void test_bit_counting(int mode) {
uint64_t test_numbers[] = {
0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255, // 基本的な小さい数値
0xFF00FF00FF00FF00ULL, // パターン化された値
0xAAAAAAAAAAAAAAAAULL, // 1010...1010パターン
0x5555555555555555ULL, // 0101...0101パターン
0xFFFFFFFFFFFFFFFFULL, // すべてのビットが1
0x8000000000000000ULL, // 最上位ビットのみ1
0x0000000000000001ULL, // 最下位ビットのみ1
0xFFFFFFFFFFFFFFFEULL, // 最下位ビット以外すべて1
0x7FFFFFFFFFFFFFFFULL, // 最上位ビット以外すべて1
0x0F0F0F0F0F0F0F0FULL, // 4ビットごとの交互パターン
0xF0F0F0F0F0F0F0F0ULL // 4ビットごとの交互パターン(逆)
};
int test_count = sizeof(test_numbers) / sizeof(test_numbers[0]);
printf("\n正確性テスト (各アルゴリズムの結果比較):\n");
printf("============================================================================================================\n");
printf("アルゴリズムの特徴と評価:\n");
printf("============================================================================================================\n");
printf("Kernighan: ビットを1つずつ消去。最適化された実装で、1の数だけ繰り返し\n");
printf("Shift: 単純な右シフトによる実装。すべてのビットを走査\n");
printf("Lookup8: 8ビット単位でテーブル参照。メモリ使用量は256バイト\n");
printf("Lookup16: 16ビット単位でテーブル参照。メモリ使用量は128KBだが高速\n");
printf("Native: コンパイラ組み込み関数。ハードウェア命令を使用可能\n");
printf("Parallel: ビット並列処理による高速化。定数時間で計算可能\n");
printf("InlineASM: CPU命令(popcnt)を直接使用。最も効率的\n");
printf("InlineASM2: 4ビット単位のルックアップによる実装\n");
printf("============================================================================================================\n\n");
printf("%20s %20s %8s %8s %8s %8s %8s %8s %8s %8s\n",
"数値", "16進数", "Kernig", "Shift", "Look8", "Look16", "Native", "Paral", "ASM", "ASM2");
printf("============================================================================================================\n");
char hex_str[17];
char dec_str[21];
for (int i = 0; i < test_count; i++) {
uint64_t num = test_numbers[i];
sprintf(hex_str, "%016lX", num);
sprintf(dec_str, "%llu", (unsigned long long)num);
int k_count = count_ones_kernighan(num, 1, mode);
int sh_count = count_ones_shift(num, 1, mode);
int l8_count = count_ones_lookup8(num, 1, mode);
int l16_count = count_ones_lookup16(num, 1, mode);
int n_count = count_ones_native(num, 1, mode);
int p_count = count_ones_parallel(num, 1, mode);
int asm_count = count_ones_inline_asm(num, 1, mode);
int asm2_count = count_ones_inline_asm2(num, 1, mode);
printf("%20s %20s %8d %8d %8d %8d %8d %8d %8d %8d\n",
dec_str, hex_str, k_count, sh_count, l8_count, l16_count, n_count, p_count, asm_count, asm2_count);
}
// MODE1のベンチマークテスト
printf("\nMODE1: 最大数(0xFFFFFFFFFFFFFFFF)を%d回繰り返し評価\n", 10000000);
fflush(stdout);
benchmark_algorithms(0xFFFFFFFFFFFFFFFFULL, 10000000, 1);
printf("\nMODE2: 0から%d未満までの値を順次評価\n", 10000000);
fflush(stdout);
benchmark_algorithms(0, 10000000, 2);
fflush(stdout);
}
int main(int argc, char *argv[]) {
setvbuf(stdout, NULL, _IONBF, 0); // 出力バッファリングを無効化
init_lookup_tables();
test_bit_counting(1); // 正確性テストはMODE1として実行
return 0;
}