UTF-8文字列からコードポイント数を計算するアルゴリズムについて紹介します。コードポイント数カウントは、シンプルに書くのはそれほど難しくないものの、高効率な実装は意外にややこしいです。
内容は二本立てです。
- 実践的な実装について、Ruby(CRuby)の内部実装(string.c)で使われているものを紹介します。
- 標準Cの範囲を超えて、SIMD命令(AVX/AVX2)を使った実装についても述べます
- 軽く検索する限りだと既知のアルゴリズムが見当たらなかったので、アドホックな実装をひねり出しましたが、そんなに効率は悪くなさそうです
おまけで簡単な性能評価をやってみました。
なお、UTF-8文字列はバリデーション済み(不正なシーケンスでないことが分かっている)であるとします。
Rubyの内部実装だとどうやっているか
まずは、それがコードポイントの先頭バイト(leading byte)かを判定するis_utf8_lead_byte
を定義します。正体はプリプロセッサマクロです。
#define is_utf8_lead_byte(c) (((c)&0xC0) != 0x80)
0xC0
つまり0b11000000
とマスクすると、上2ビットだけが残ります。UTF-8の先頭でないバイトは0b10xxxxxx
という並びになっているので、もし非先頭バイトであればマスク後には0b10000000
となりますが、これはつまり比較演算!= 0x80
で先頭バイトか判定できることを意味します。
このマクロを使うと、簡単にコードポイントカウントを実装できます。
#define is_utf8_lead_byte(c) (((c)&0xC0) != 0x80)
int64_t count_utf8_codepoint(const char *p, const char *e) {
while (p < e) {
if (is_utf8_lead_byte(*p)) len++;
p++;
}
return len;
}
さて、ここまでは単純ですが、Ruby処理系における実装はもう少し工夫がしてあります。
上記の素朴な実装ではバイト単位で文字列処理をしていますが、現代のPC・サーバなどの典型的な環境であれば32/64ビット整数を不自由なく扱えるため、1バイトごとの処理は相対的に非効率です。
そこで、複数バイトをuintptr_t
型でまとめて読み込み、含まれる全ての先頭バイトの数を一気に数えるということをします。コードポイント数は先頭バイト数に等しい ので、コードポイントを数えるには先頭バイトを数えればよいのです。
#define NONASCII_MASK UINT64_C(0x8080808080808080)
static inline uintptr_t
count_utf8_lead_bytes_with_word(const uintptr_t *s)
{
uintptr_t d = *s;
d = (d>>6) | (~d>>7);
d &= NONASCII_MASK >> 7;
return __builtin_popcountll(d);
}
※ (インクルードヘッダを除き)コード片のみでコンパイルできるように少し書き換えています。また、sizeof(uintptr_t) == 8
の環境を前提に記述しています。
まとめて計算してるのでちょっとややこしいですが、次の点に注意してください。
-
NONASCII_MASK >> 7
は各バイトの最下位ビットだけが立っているマスク-
d
はこれとbit andされるので、最下位ビットのみが残る
-
-
(d>>6) | (~d>>7)
の各バイト最下位ビットに着目すると「7ビット目が立っている or 8ビット目が落ちている」判定になっている- 7ビット目が立っている:UTF-8の先頭でないバイトは、常に
0b10xxxxxx
という並びになっていることを思い出すと、つまり先頭バイト - 8ビット目が落ちている:ASCIIなので先頭バイト
- 7ビット目が立っている:UTF-8の先頭でないバイトは、常に
ビット演算後のd
は、各バイトについて最下位ビットが立っていると先頭バイトだということを表すようになります。その他のビットはマスクされているので常に0です。したがって、含まれる先頭バイトの数を数えるには、popcntをすればいいというわけです。
最後にようやくですが、count_utf8_lead_bytes_with_word
を使った実装の全体像を示します。
int64_t ruby_count_utf8_codepoint(const char *p, const char *e)
{
uintptr_t len = 0;
if ((int)sizeof(uintptr_t) * 2 < e - p) {
const uintptr_t *s, *t;
const uintptr_t lowbits = sizeof(uintptr_t) - 1;
s = (const uintptr_t*)(~lowbits & ((uintptr_t)p + lowbits));
t = (const uintptr_t*)(~lowbits & (uintptr_t)e);
while (p < (const char *)s) {
if (is_utf8_lead_byte(*p)) len++;
p++;
}
while (s < t) {
len += count_utf8_lead_bytes_with_word(s);
s++;
}
p = (const char *)s;
}
while (p < e) {
if (is_utf8_lead_byte(*p)) len++;
p++;
}
return (long)len;
}
なんでこんなに複雑な実装なんだという話ですが、count_utf8_lead_bytes_with_word
を正しく動作させるためには、uintptr_t
のnバイト境界(典型的にはn = 4 or 8だと思います)にアライメントされたポインタから読み出す必要があるからです。具体的なやり方としては次のようになります。
-
lowbits
を足してからマスクすることで、p
をアライメントされた位置まで進めてs
とする -
lowbits
を引いてからマスクすることで、e
をアライメントされた位置まで戻してt
とする
こうして前処理されたs
, t
に対してならcount_utf8_lead_bytes_with_word
が使えます。
もちろんs
,t
とp
, e
にはズレが生じるかもしれないで、普通のループ処理を使って前後の差分を埋めています。
AVXによるアルゴリズム/実装
(2019/04/18追記) @umezawatakeshi さんにここの実装を改良してもらいました。AVX実装に興味のある方はそちらの記事も合わせて参照ください。
https://qiita.com/umezawatakeshi/items/ed23935788756c800b86
アルゴリズム、というほどのものではないので、実装と並行して解説していきます。
基本的な方針としては、UTF-8バリデーションのアルゴリズムと同じく、バイトの上位ニブルを見れば先頭バイトかどうかわかるという性質を使います。vpshufbを使えば先頭バイトのあった位置だけに1を立てることができます。
1の数え方ですが、SIMDでは水平方向に加算するのはコストがかかるので、32バイト単位で切り出されたベクトルをそのまま足し合わせていくことにします。ただし、バイトで表せる値の範囲は0〜255であることから、ベクトル加算は255回を超えて行うとオーバーフローが起こりかねません。
回避策としてはループを分割し、255回ごとに水平加算で値を集約していくことにします(もちろん入力が無くなったら255回以前にループは終了します)。こうすると水平加算は最小で1回、最大でも255回に1回しか実行されないので、相対的にコストが低くなると期待できます。
さて実装です。まず、補助関数として、バイト単位で水平加算する関数を定義しておきます。
inline int32_t avx2_horizontal_sum_epi8(__m256i x)
{
__m256i sumhi = _mm256_unpackhi_epi8(x, _mm256_setzero_si256());
__m256i sumlo = _mm256_unpacklo_epi8(x, _mm256_setzero_si256());
__m256i sum16x16 = _mm256_add_epi16(sumhi, sumlo);
__m256i sum16x8 = _mm256_add_epi16(sum16x16, _mm256_permute2x128_si256(sum16x16, sum16x16, 1));
__m256i sum16x4 = _mm256_add_epi16(sum16x8, _mm256_shuffle_epi32(sum16x8, _MM_SHUFFLE(0, 0, 2, 3)));
uint64_t tmp = _mm256_extract_epi64(sum16x4, 0);
int32_t result = 0;
result += (tmp >> 0 ) & 0xffff;
result += (tmp >> 16) & 0xffff;
result += (tmp >> 32) & 0xffff;
result += (tmp >> 48) & 0xffff;
return result;
}
avx2_horizontal_sum_epi8
を用いると、32の倍数についてコードポイントをカウントする関数は比較的簡単に書けます。
int64_t avx_count_utf8_codepoint(const char *p, const char *e)
{
// `p` must be 32B-aligned pointer
p = static_cast<const char *>(__builtin_assume_aligned(p, 32));
const size_t size = e - p;
int64_t result = 0;
for (size_t i = 0; i + 31 < size;) {
__m256i sum = _mm256_setzero_si256();
size_t j = 0;
for (; j < 255 * 32 && (i + 31) + j < size; j += 32) {
const __m256i table = _mm256_setr_epi8(
1, 1, 1, 1, 1, 1, 1, 1, // .. 0x7
0, 0, 0, 0, // 0x8 .. 0xB
1, 1, 1, 1, // 0xC .. 0xF
1, 1, 1, 1, 1, 1, 1, 1, // .. 0x7
0, 0, 0, 0, // 0x8 .. 0xB
1, 1, 1, 1 // 0xC .. 0xF
);
__m256i s = _mm256_load_si256(reinterpret_cast<const __m256i *>(p + i + j));
s = _mm256_and_si256(_mm256_srli_epi16(s, 4), _mm256_set1_epi8(0x0F));
s = _mm256_shuffle_epi8(table, s);
sum = _mm256_add_epi8(sum, s);
}
i += j;
result += avx2_horizontal_sum_epi8(sum);
}
return result;
}
内側のループでベクトルごとの加算をし、外側のループで加算後のベクトルの要素値を集約してresult
に足していきます。
table
はUTF-8の上位ニブルから先頭バイトを検出し、先頭バイトのみを1にするための変換テーブルです。
内側のループの条件式j < 255 * 32 && (i + 31) + j < size
は少しややこしいです。ループ変数j
は32ずつ加算されるので、条件式の前半j < 255 * 32
はループを255回で打ち切るという意味です。後半の(i + 31) + j < size
は、入力長さsize
を32の倍数を超えてはみ出さないためのガードです。内側のループが終了するとき、i
にj
が加算されるので外側ループのi + 31 < size
も0になり、ループは終了します。
コラボ実装
32の倍数からはみ出た部分についてはRuby実装と組み合わせることが可能です。
今回は、はみ出た部分は1バイトずつ処理する実装で埋めてみました。p
が32B境界に整列していない可能性もあるため、ベクトル処理の前に32B境界まで1バイトずつスカラで進める処理も手前に挿入しました。
※ (2019/04/07 00:30追記)バグがあったためコードの修正をしました。
int64_t count_utf8_codepoint(const char *p, const char *e)
{
int64_t count = 0;
#if defined(__AVX2__)
if (32 <= e - p) {
// increment `p` to 32B boundary
while (((uintptr_t)p % 32) != 0) {
if (is_utf8_lead_byte(*p)) count++;
p++;
}
// vectorized count
count += avx_count_utf8_codepoint(p, e);
p += static_cast<uintptr_t>(e - p) / 32 * 32;
}
#endif
while (p < e) {
if (is_utf8_lead_byte(*p)) count++;
p++;
}
return count;
}
性能評価
今回は、長い文字列に対して十分にスループットが出ているか、それは定量的にどの程度かを見ていきます。
性能評価の条件について簡単に述べます。
- 環境: Ubuntu 18.04.1 on VirtualBox on MBP Late 2013
- つまりCPUはHaswell世代
- 文字列のサイズは100MiB
- ちゃんとした文字列ではなく乱数列を使う
- 測定値としてrdtsc命令から取得したクロックサイクル数を使う
- 測定は3回行い、中央の値をとる
- 100MiBだと1回の処理が短すぎるので、同じ処理を100回した区間を測定する
- コンパイラとしてClang 8.0とGCC 7.3.0を使う
- コンパイラオプション(共通):
-O3 -masm=intel --std=c++14 -mavx -mavx2 -Wall -Wextra
- コンパイラオプション(共通):
測定対象は、紹介したスカラ版と、Ruby版と、コラボ実装の節で述べたcount_utf8_codepoint
(以下AVX版)とします。
測定結果
実装 | コンパイラ | クロックサイクル数(Mclk) |
---|---|---|
スカラ版 | GCC | 7323 |
スカラ版 | Clang | 8793 |
Ruby版 | GCC | 4573 |
Ruby版 | Clang | 2643 |
AVX版 | GCC | 2361 |
AVX版 | Clang | 2345 |
なぜこのような結果になったのか?
全体的な結果としてはスカラ版が最も遅く、AVX版が最も高速になりました。
Ruby実装はGCCとClangで大幅な差が付いていますが、これはClangの自動ベクトル化がうまく効いているからのようです。
GCCの自動ベクトル化は、Ruby版の最も重い場所(while (s < t)
のループ)について効いておらず、そこで大幅な差が付いています。
興味深いことに、スカラ版を自動ベクトル化した場合には、GCCの方が効率的なコード生成ができているように見受けられます。Ruby版のアルゴリズムからClangはヒントを得たんでしょうか……測定ミスでなければいいのですが。
同じ処理系(Clang)についてRuby版とAVX版を比較すると、AVXコードを手書きすることで1割ほど高速化できています。
しかしこれはコンパイラの支援を受けることで、比較的たやすくintrinsics手書きの9割程度の性能を得られたとも言えるため、あとはコストパフォーマンスの問題になりそうだと見受けられます。
ちなみに実時間としては、AVX版-Clangはおよそ0.838秒でした。つまり12GiB/s程度は出ていることになります。文字列が十分大きいためキャッシュは効いておらず、このMBPに付いているDDR3-1600 2chメモリは理論帯域25.6GiB/sであることを踏まえると、およそ帯域の5割ほどを使えていることになります。シングルスレッドのプログラムとしては十分すぎるような気はします。
まとめ
- UTF-8のコードポイントカウントを題材にして、実践的な実装を紹介しました
- UTF-8バリデーションアルゴリズムを応用したカウントアルゴリズムを示し、文字列が巨大である場合について実用性を検証しました
- コンパイラによって性能差・特性差が大きいことも発見しました
Appendix
ソースコード