Help us understand the problem. What is going on with this article?

UTF-8のコードポイントはどうやってもっと高速に数えるか

More than 1 year has passed since last update.

(この記事は私の blog の http://umezawa.dyndns.info/wordpress/?p=7236 の転載です)

UTF-8のコードポイントはどうやって高速に数えるかという記事がありました。コードを眺めながらもっと速くなるんじゃないのと思ったので、やってみようと思います。

元のコードはこうなっています。

inline int32_t avx2_horizontal_sum_epi8(__m256i x)
{
    __m256i sumhi = _mm256_unpackhi_epi8(x, _mm256_setzero_si256());
    __m256i sumlo = _mm256_unpacklo_epi8(x, _mm256_setzero_si256());
    __m256i sum16x16 = _mm256_add_epi16(sumhi, sumlo);
    __m256i sum16x8 = _mm256_add_epi16(sum16x16, _mm256_permute2x128_si256(sum16x16, sum16x16, 1));
    __m256i sum16x4 = _mm256_add_epi16(sum16x8, _mm256_shuffle_epi32(sum16x8, _MM_SHUFFLE(0, 0, 2, 3)));
    uint64_t tmp = _mm256_extract_epi64(sum16x4, 0);
    int32_t result = 0;
    result += (tmp >> 0 ) & 0xffff;
    result += (tmp >> 16) & 0xffff;
    result += (tmp >> 32) & 0xffff;
    result += (tmp >> 48) & 0xffff;
    return result;
}

int64_t avx_count_utf8_codepoint(const char *p,  const char *e)
{
    // `p` must be 32B-aligned pointer
    p = static_cast<const char *>(__builtin_assume_aligned(p, 32));
    const size_t size = e - p;
    int64_t result = 0;
    for (size_t i = 0; i + 31 < size;) {
        __m256i sum = _mm256_setzero_si256();
        size_t j = 0;
        for (; j < 255 * 32 && (i + 31) + j < size; j += 32) {
            const __m256i table = _mm256_setr_epi8(
                    1, 1, 1, 1, 1, 1, 1, 1, //     .. 0x7
                    0, 0, 0, 0,             // 0x8 .. 0xB
                    1, 1, 1, 1,             // 0xC .. 0xF
                    1, 1, 1, 1, 1, 1, 1, 1, //     .. 0x7
                    0, 0, 0, 0,             // 0x8 .. 0xB
                    1, 1, 1, 1              // 0xC .. 0xF
                    );
            __m256i s = _mm256_load_si256(reinterpret_cast<const __m256i *>(p + i + j));
            s = _mm256_and_si256(_mm256_srli_epi16(s, 4), _mm256_set1_epi8(0x0F));
            s = _mm256_shuffle_epi8(table, s);
            sum = _mm256_add_epi8(sum, s);
        }
        i += j;
        result += avx2_horizontal_sum_epi8(sum);
    }
    return result;
}

なお、高速化する前に、引数や返り値をちょこっとだけ変えてあります。関数の頭で計算していた値を引数で渡したり int64_t が size_t になったりしてるだけなので、速度には影響はありません。

size_t avx_count_utf8_codepoint(const char *p, size_t sz)
{
    size_t result = 0;
    for (size_t i = 0; i + 31 < sz;) {
        __m256i sum = _mm256_setzero_si256();
        size_t j = 0;
        for (; j < 255 * 32 && (i + 31) + j < sz; j += 32) {
            const __m256i table = _mm256_setr_epi8(
                1, 1, 1, 1, 1, 1, 1, 1, //     .. 0x7
                0, 0, 0, 0,             // 0x8 .. 0xB
                1, 1, 1, 1,             // 0xC .. 0xF
                1, 1, 1, 1, 1, 1, 1, 1, //     .. 0x7
                0, 0, 0, 0,             // 0x8 .. 0xB
                1, 1, 1, 1              // 0xC .. 0xF
            );
            __m256i s = _mm256_load_si256(reinterpret_cast<const __m256i *>(p + i + j));
            s = _mm256_and_si256(_mm256_srli_epi16(s, 4), _mm256_set1_epi8(0x0F));
            s = _mm256_shuffle_epi8(table, s);
            sum = _mm256_add_epi8(sum, s);
        }
        i += j;
        result += avx2_horizontal_sum_epi8(sum);
    }
    return result;
}

ループ内の処理

このアルゴリズムでは 0x00~0x7f または 0xc0~0xff であるバイトを数えるわけですが、これは signed byte として見ると -0x40~0x7f となります。なので、 VPSHUFB でテーブルを引かなくても VPCMPGTB 命令一発で判定できます(たぶん元の記事の前の記事にある UTF-8 バリデーションのコードを改造したからこうなってるんだと思う)。この場合判定を通ったバイトは 1 ではなく 0xff (=-1) になりますが、集計する際に VPADDB ではなく VPSUBB にすれば問題ありません。

size_t opt_innermost_content(const char *p, size_t sz)
{
    size_t result = 0;
    for (size_t i = 0; i + 31 < sz;) {
        __m256i sum = _mm256_setzero_si256();
        size_t j = 0;
        for (; j < 255 * 32 && (i + 31) + j < sz; j += 32) {
            __m256i s = _mm256_load_si256(reinterpret_cast<const __m256i *>(p + i + j));
            sum = _mm256_sub_epi8(sum, _mm256_cmpgt_epi8(s, _mm256_set1_epi8(-0x41)));
        }
        i += j;
        result += avx2_horizontal_sum_epi8(sum);
    }
    return result;
}

こうするとロード、判定、集計がそれぞれ1命令 (1uOP) になります。さすがにこれ以上命令数は減らないでしょう。また、Haswell の場合 VPCMPGTB と VPADDB は両方とも port1/5 で発行できてスループットが 0.5 です。このコードは sum に VPADDB するところにループ間の依存関係がありますが、横で VPCMPGTB していてループ間の依存関係を改善しても実行ユニットが足りなくて速くならないので、ループ内のコードに関しては理論上は最速になります。多分これが一番速いと思います。

いきなり最速になってしまったので、ここからは UTF-8 は関係なく条件を満たすバイトをどう高速に数えるかという話に移ります。

ループ終了条件

内側のループの終了条件が and になっていていかにも判定が遅そうです。判定はシンプルにしましょう。ちなみに 31 を足して判定している部分がありますが、この関数は32バイト単位でしか処理しないのでやらなくても同じです。

size_t opt_innermost_content_loopend(const char *p, size_t sz)
{
    size_t result = 0;
    for (size_t i = 0; i < sz;) {
        __m256i sum = _mm256_setzero_si256();
        size_t j = 0;
        size_t limit = std::min<size_t>(255 * 32, sz - i);
        for (; j < limit; j += 32) {
            __m256i s = _mm256_load_si256(reinterpret_cast<const __m256i *>(p + i + j));
            sum = _mm256_sub_epi8(sum, _mm256_cmpgt_epi8(s, _mm256_set1_epi8(-0x41)));
        }
        i += j;
        result += avx2_horizontal_sum_epi8(sum);
    }
    return result;
}

この記事には書きませんが元のコードに適用したものも実装しておきます。

とりあえず計測

計測条件は以下の通りです。

  • CPU: Core i7-4770 @3.4GHz (Haswell) TB/EIST off
  • RAM: DDR3-1600 (PC3-12800) CL9 dual channel
  • カウント対象は 16K, 224K, 6M, 128M 。これはそれぞれ L1 (32K), L2 (256K), L3 (8M) キャッシュに収まるサイズと、収まらずにメインメモリから読む状態を計測することになります。
  • コンパイラは Clang 8.0.0 の Visual Studio 2017 integration (clang-cl)
    • オプションは /Arch:AVX2 /O2 (-mavx2 -O2 相当)
    • なんでこんな珍しい環境なのかというと上記マシンは Windows だからです。一応比較として opt_innermost_content_loopend を MS のコンパイラでコンパイルした時の結果も載せておきます。


で、結果です。上段が処理速度、下段が32バイト処理するのにかかったクロック数です。

実装\サイズ 16KiB 224KiB 6MiB 128MiB
avx_count_utf8_codepoint 34.7GB/s
3.13
33.9GB/s
3.21
29.7GB/s
3.66
15.5GB/s
7.03
avx_count_utf8_codepoint_loopend 61.3GB/s
1.77
52.5GB/s
2.07
35.9GB/s
3.03
15.8GB/s
6.86
opt_innermost_content 43.6GB/s
2.49
41.1GB/s
2.65
33.3GB/s
3.27
16.0GB/s
6.80
opt_innermost_content_loopend 100.2GB/s
1.09
77.7GB/s
1.40
38.4GB/s
2.83
16.8GB/s
6.49
opt_innermost_content_loopend (cl.exe) 72.4GB/s
1.49
64.4GB/s
1.69
37.3GB/s
2.90
16.7GB/s
6.52

下限である32バイトあたり1クロックに近い速度が出ています。また、メインメモリにアクセスする場合でも1割弱速くなっています。memtest86+ によればメインメモリは 20GB/s 出るらしいのですが、そこまでは到達できませんでした。

ちなみに元コードの128MBが元記事より速くなっていますが、メインメモリの違いによる現象だと思っています(同じ DDR3-1600 でもアクセスタイミングが違う)。

さらに速くなるかどうか

opt_innermost_content_loopend に対して Clang が吐いたバイナリを見ると8倍にループアンローリングしています。アンローリングすると分岐の割合が下がりますが、これ以上の改善はループ終了条件の効率化によるしかないので、最適化余地はかなり小さくなっているはずです。大きなデータを相手にする場合はさらにアンローリングすることも考えられますが、大きなデータは L1 キャッシュに載らずメモリ側のスループットが下がるので、やっても意味がなくなります。

というか十分速くなっちゃったし下限が32バイトあたり1クロックであることは分かっているので、あんまりやる気が出ないというのが実際のところです。速くしてもこれじゃ計測誤差に埋もれちゃうし…。(実際にはやったけど速くならなかった)

まとめ

  • UTF-8 のコードポイント数をカウントするコードを最適化しました。非常に大きなデータの場合はメインメモリで律速するためそれほど効果はありませんが、コアに近いキャッシュに収まるサイズであるほど効果が大きく出ます。
  • 実際のところ人間が頑張る部分はそれほどありませんでした。 Clang すごいっすね。

Appendix

ソースコード(Windows + Visual Studio 2017 + LLVM integration 向けなので Unix 向けには多少修正する必要があります)

次回予告 → UTF-8のコードポイントはどうやってAVX-512で高速に数えるか

umezawatakeshi
AVX-512 が実装されたパフォーマンスデスクトップ向けプロセッサを待ち続けて早幾年月
http://umezawa.dyndns.info/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした