入力されたバイナリ列がUTF-8として正当(well-formed)かバリデーションするアルゴリズムは色々な所で使われますが、高速化は難しい問題です。
先日たまたま見つけた記事・論文が、バリデーション問題を非常にうまく取り扱っていて勉強になったので、Qiita記事にまとめます。記事はJSONのパースをSIMD(x86のAVX)によって高速化したというものなのですが、部分問題として文字列の整数変換やUTF-8バリデーションについても高速実装を実現しているようです。
この記事では、取り扱う対象となるUTF-8について簡単に述べた後、使われているテクニックのいくつかを紹介し、次にアルゴリズムの解説をしていきます。
C/C++言語、AVXおよびintrinsicsについては既知のものとして記述しましたのでご注意ください。
読んだ記事・論文
- https://branchfree.org/2019/02/25/paper-parsing-gigabytes-of-json-per-second/
-
https://arxiv.org/abs/1902.08318
- 論文でいう対応する章は "3.1.5 Character-Encoding Validation" です
実装も提供されています。たくさんstarがついてますね。
https://github.com/lemire/simdjson
UTF-8についての簡単なおさらい
Unicodeを具体的なバイトの並びに符号化する形式の一つです。
1コードポイントは1-4バイトの可変長になることが特徴で、これがSIMDでの取り扱いを難しくしています。
Code Points | First Byte | Second Byte | Third Byte | Fourth Byte |
---|---|---|---|---|
U+0000..U+007F | 00..7F | |||
U+0080..U+07FF | C2..DF | 80..BF | ||
U+0800..U+0FFF | E0 | A0..BF | 80..BF | |
U+1000..U+CFFF | E1..EC | 80..BF | 80..BF | |
U+D000..U+D7FF | ED | 80..9F | 80..BF | |
U+E000..U+FFFF | EE..EF | 80..BF | 80..BF | |
U+10000..U+3FFFF | F0 | 90..BF | 80..BF | 80..BF |
U+40000..U+FFFFF | F1..F3 | 80..BF | 80..BF | 80..BF |
U+100000..U+10FFFF | F4 | 80..8F | 80..BF | 80..BF |
参照: http://www.unicode.org/versions/Unicode6.0.0/ch03.pdf の "Table 3-7. Well-Formed UTF-8 Byte Sequences"
アルゴリズムを理解する上で重要なUTF-8の特徴について述べます。
-
1コードポイントは1-4バイトのシーケンスで表現される
-
上位ニブル(1バイト8ビットのうち、上位4ビット)を確認することでシーケンスの情報が得られる
- そのバイトがシーケンス先頭バイトかどうかわかる
- もしそれがシーケンス先頭バイトだったなら、何バイトのシーケンスかわかる
-
先頭でないバイトは基本的に0x80..0xBFの範囲が許容されているが、何箇所か例外があるのでそれもバリデーションしなければならない
- 例外の箇所は表では太字で示した
- 例えば、表を見て分かるように、先頭バイトが0xE0のとき2バイト目は0xA0..0xBFのみ許容されている
- 例外は2バイト目にしかない
表で許容されている以外のバイト並びは全て不正(ill-formed)です。例えば、0xC080はNUL文字(U+0000)を無理やり2バイトで表現したものですが、こういった非最短形式のシーケンスは禁止されています。非最短形式はセキュリティリスクになるため、バリデーションでは弾かなければなりません。詳しくは↓のリンクを参照してください。
様々な技法
アルゴリズムを示す前に、内部的に技巧的な実装がいくつも使われているため、それぞれを順番に解説します。
経験のあるx86/SSE/AVXエンジニアの方であればどれも知っているものかもしれません。
上位ニブルの取り出し
全てのバイトについて、上半分の4ビットのみを抽出します。
例えばビット列 aaaa bbbb cccc dddd eeee ffff
があった時に、 0000 aaaa 0000 dddd 0000 ffff
へと変換します。
実装としては右に4ビットシフトした後に下4ビットでマスクします。1バイト単位のシフト命令はAVXに存在しませんが、今回は2バイト単位でシフトしても(マスクするので)問題ありません。
inline high_nibbles(__m256i bytes) {
return _mm256_and_si256(_mm256_srli_epi16(bytes, 4), _mm256_set1_epi8(0x0F));
}
16バイトのテーブル参照
インデックス値に対応した何らかの値を得る場合、テーブルを引くという実装方法があります。
AVXにおいては、インデックスが4ビット(16通り)までに限定されるものの、vpshufbでテーブル参照が実現できます。
例としてはちょっと極端ですが、入力の二乗を計算してみます(入力値0〜15までしか対応していません)。
__m256i avx_pow2(__m256i x) {
const __m256i table = _mm256_setr_epi8(
0, 1, 4, 9, 16, 25, 36, 49, 64, 81,100,121,144,169,196,225,
0, 1, 4, 9, 16, 25, 36, 49, 64, 81,100,121,144,169,196,225
);
return _mm256_shuffle_epi8(table, x);
}
飽和演算で値の範囲チェック
特定の値より大きいかどうかを判定する際に、飽和減算で代用することができます1。
引き算した結果が(符号なしの場合)0であればその値以下だということです。
レジスタ間のバイト単位シフト(アライメント)
あるSIMDレジスタの要素をバイト単位でずらし、空いた箇所に別のレジスタの要素をずらしてきて入れるという処理です。
vpalignrでできる……かと思いきや、この命令は128ビットのレーン単位でしか処理してくれないので、レーン間シャッフルを組み合わせます。
参考: https://qiita.com/beru/items/dbbebaf4630fb7bf4022
今回のアルゴリズムでは、特にシフト量が1ないし2のものを使います。
static inline __m256i push_last_byte_of_a_to_b(__m256i a, __m256i b) {
return _mm256_alignr_epi8(b, _mm256_permute2x128_si256(a, b, 0x21), 15);
}
static inline __m256i push_last_2bytes_of_a_to_b(__m256i a, __m256i b) {
return _mm256_alignr_epi8(b, _mm256_permute2x128_si256(a, b, 0x21), 14);
}
アルゴリズム
アルゴリズムの入力はバリデーション対象のバイト列です。実行結果として、そのバイト列がUTF-8として正当であるかを出力します。
入力はSIMDレジスタにロードされ、32バイトずつ(=ymmレジスタ長ずつ)処理します。また、各ステップで以前にロードしたバイトを使うことがあるので、前回の32バイトを捨てずに保持しておき適宜利用します(中間的な処理結果も保持しておきます)。
32バイトごとの処理は、次のステップから構成されています。
- 0xF4を超えたバイトが無いかをチェック
- 先頭バイトと後続バイトの関係性のチェック
- 0xEDが見つかったとき次のバイトが0x9Fを超えていないことのチェック。同様に0xF4と0x8Fの関係についてもチェック
- その他のバイトの関係性についてのチェック
- 先頭バイトの上位ニブルから、先頭バイトおよび2バイト目の値が許容される値より小さくないかをチェックします
これらは実装の順序がこうなっているので順番に書きましたが、実行に依存関係はありません。それどころか、 アルゴリズム中には条件分岐が含まれていません。2 したがって、それぞれのステップでエラーが見つかったかに関わらず、全てのステップが実行されます。各ステップの結果として得られた「エラーが見つかったか」フラグをbit orして「どれかのステップでエラーが見つかったか」を得ます。
以降、それぞれのステップについて見ていきます。
0xF4を超えたバイトが無いかをチェック
前述のように飽和減算を使います。エラーが見つかった場合、変数has_error
に非0が蓄積されます。
// all byte values must be no larger than 0xF4
static inline void avxcheckSmallerThan0xF4(__m256i current_bytes,
__m256i *has_error) {
// unsigned, saturates to 0 below max
*has_error = _mm256_or_si256(
*has_error, _mm256_subs_epu8(current_bytes, _mm256_set1_epi8(0xF4)));
}
先頭バイトと後続バイトの関係性のチェック
このステップでは、先頭バイトと後続バイトの関係が正しいかをチェックします。先頭バイトから期待されるシーケンス長と実際のシーケンス長が合ってなかったり、逆に先頭バイトが無いのに2バイト目以降でしか存在し得ないバイトが出現していることを検出します。
まず、次のような変換をかけます。
- 先頭バイトはそのバイトを含むシーケンス長に変換する
- 非先頭バイトは0に変換する
例えばあa
はUTF-8では0xE3 0x81 0x82 0x61
ですが、これを3 0 0 1
に変換します。
これの実装は、全てのバイトからニブルだけを抽出したもの(前述のように_mm256_srli_epi16
を使えばできます)に対してvpshufbによるテーブル参照を使います。
static inline __m256i avxcontinuationLengths(__m256i high_nibbles) {
return _mm256_shuffle_epi8(
_mm256_setr_epi8(1, 1, 1, 1, 1, 1, 1, 1, // 0xxx (ASCII)
0, 0, 0, 0, // 10xx (continuation)
2, 2, // 110x
3, // 1110
4, // 1111, next should be 0 (not checked here)
1, 1, 1, 1, 1, 1, 1, 1, // 0xxx (ASCII)
0, 0, 0, 0, // 10xx (continuation)
2, 2, // 110x
3, // 1110
4 // 1111, next should be 0 (not checked here)
),
high_nibbles);
}
変換後のバイト列に対して、次の処理を順番に行います。
- 変換後バイト列から1バイト分だけ要素をずらし1引いたものを、変換後バイト列に足す
- 変換後バイト列から2バイト分だけ要素をずらし2引いたものを、変換後バイト列に足す
前述のあa
の例でいうと、バイト列3 0 0 1
は3 2 1 1
に変換されます。
こんな変なことをする理由は「被覆」による判定のためです。この処理によって先頭バイトは(自分自身を含めた)長さの分だけ後続バイトに値を加算することになります。その結果、2バイト目の0と3バイト目の0に値が足され、非0になりました。
逆に言うと、先頭バイトが存在しない不正なシーケンスだった場合、0は0のまま足されず残ります。したがって、0が残っているかを見れば、そのバイト列が不正なシーケンスだったかどうかが判定できるわけです。
もう一つの不正な例として、本来は非先頭バイトがあるべき場所に先頭バイトがあるというパターンが考えられます。このパターンを検出するには、「元々は0でなかった要素の値が増えている」という条件式を使います。
なお、処理単位である32バイトの先頭付近の被覆を正しく判定するには、一つ前の32バイトの末尾2バイトが必要になります。そのため、前回の32バイトの処理結果を別途保持しておき、2バイト分だけずらして処理します。
static inline __m256i avxcarryContinuations(__m256i initial_lengths,
__m256i previous_carries) {
__m256i right1 = _mm256_subs_epu8(
push_last_byte_of_a_to_b(previous_carries, initial_lengths),
_mm256_set1_epi8(1));
__m256i sum = _mm256_add_epi8(initial_lengths, right1);
__m256i right2 = _mm256_subs_epu8(
push_last_2bytes_of_a_to_b(previous_carries, sum), _mm256_set1_epi8(2));
return _mm256_add_epi8(sum, right2);
}
static inline void avxcheckContinuations(__m256i initial_lengths,
__m256i carries, __m256i *has_error) {
// overlap || underlap
// carry > length && length > 0 || !(carry > length) && !(length > 0)
// (carries > length) == (lengths > 0)
__m256i overunder = _mm256_cmpeq_epi8(
_mm256_cmpgt_epi8(carries, initial_lengths),
_mm256_cmpgt_epi8(initial_lengths, _mm256_setzero_si256()));
*has_error = _mm256_or_si256(*has_error, overunder);
}
実際の実装では、バイト列の変換をavxcarryContinuations
で行い、変換後のバイト列に対しての被覆判定をavxcheckContinuations
で行なっています。ちょっと驚きですが、エラーの検出をするための条件式としては、3回の比較をしているだけです。
0xEDが見つかったとき次のバイトが0x9Fを超えていないことのチェック。同様に0xF4と0x8Fの関係についてもチェック
あまり面白くないので省略します。
1バイト分バイト単位シフトしたバイト列を作り、0xED(0xF4)と0x9F(0x8F)をそれぞれ同値判定したものでANDを取るだけです。
https://github.com/lemire/simdjson/blob/v0.1.1/include/simdjson/simdutf8check.h#L91-L109
その他のバイトの関係性についてのチェック
ここがアルゴリズム全体で最もトリッキーな箇所かもしれません。
連続する2つのバイトについて、上位ニブルから変換テーブルを引くということをします。
目的は、シーケンスの先頭バイトと2バイト目について境界チェックをすることです。チェックをするのは最小値だけです。特別な最大値(0x9Fと0x8F)については前のステップで確認済みだからです。
もちろん先頭バイトと2バイト目以外についてもテーブルによる変換が入りますが、境界値として最低値を入れておくことで、常にチェックが通すことにします(つまり、このステップでは3、4バイト目に関心がありません)。
具体的に見ていくために、UTF-8の表のうち、シーケンス先頭バイトの上位ニブル0xEの場合について再掲します。
Code Points | First Byte | Second Byte | Third Byte | Fourth Byte |
---|---|---|---|---|
U+0800..U+0FFF | E0 | A0..BF | 80..BF | |
U+1000..U+CFFF | E1..EC | 80..BF | 80..BF | |
U+D000..U+D7FF | ED | 80..9F | 80..BF | |
U+E000..U+FFFF | EE..EF | 80..BF | 80..BF |
シーケンス先頭の上位ニブルが0xEだった場合、次のチェックを行います。
- 先頭バイトが0xE1以上か(0xE0でない点に注意してください)
- 2バイト目が0xA0以上か(0x80でない点に注意してください)
2つのうち 両方を 満たせていない場合に、エラーとします。
なぜこれで過不足ないチェックになっているかは若干ややこしいですが、順番に整理します。
- 先頭バイトが0xE0のとき、2バイト目が0xA0以上かでエラーかが定まる。これはU+0800..U+0FFFの範囲チェックに他ならない
- 先頭バイトが0xE1以上のとき、2バイト目の最小値は0x80であるという以上の制約はない。これはU+1000..U+CFFFの範囲チェックに他ならない
- 先頭バイトが0xED, 0xEE, 0xEFのとき、0xE1と同様の議論でやはり範囲チェックとしては妥当
なお、そもそも2バイト目が0x80を下回ることはないのか? と言う点ですが、0x7F以下は1バイト表現なためASCIIです。そしてもし2バイト以上のシーケンス中に1バイト表現が現れていたら、前述した関係性のチェックでエラーが検出できるはずです。
以上の例示は上位ニブルが0xEだった場合ですが、0xC, 0xD, 0xFについても同様の考え方でテーブルを作ることができます。
実装を示します。先の例示の通り、テーブルのE番目にそれぞれ0xE1
、0xA0
が入っていることがわかると思います。
// map off1_hibits => error condition
// hibits off1 cur
// C => < C2 && true
// E => < E1 && < A0
// F => < F1 && < 90
// else false && false
static inline void avxcheckOverlong(__m256i current_bytes,
__m256i off1_current_bytes, __m256i hibits,
__m256i previous_hibits,
__m256i *has_error) {
__m256i off1_hibits = push_last_byte_of_a_to_b(previous_hibits, hibits);
__m256i initial_mins = _mm256_shuffle_epi8(
_mm256_setr_epi8(-128, -128, -128, -128, -128, -128, -128, -128, -128,
-128, -128, -128, // 10xx => false
0xC2, -128, // 110x
0xE1, // 1110
0xF1, -128, -128, -128, -128, -128, -128, -128, -128,
-128, -128, -128, -128, // 10xx => false
0xC2, -128, // 110x
0xE1, // 1110
0xF1),
off1_hibits);
__m256i initial_under = _mm256_cmpgt_epi8(initial_mins, off1_current_bytes);
__m256i second_mins = _mm256_shuffle_epi8(
_mm256_setr_epi8(-128, -128, -128, -128, -128, -128, -128, -128, -128,
-128, -128, -128, // 10xx => false
127, 127, // 110x => true
0xA0, // 1110
0x90, -128, -128, -128, -128, -128, -128, -128, -128,
-128, -128, -128, -128, // 10xx => false
127, 127, // 110x => true
0xA0, // 1110
0x90),
off1_hibits);
__m256i second_under = _mm256_cmpgt_epi8(second_mins, current_bytes);
*has_error = _mm256_or_si256(*has_error,
_mm256_and_si256(initial_under, second_under));
}
なお、テーブル中の定数-128の箇所は、比較時に常にfalseとなります。127は常にtrueです。
まとめ
UTF-8バリデーションのSIMD化は、正直、直感的には効果的に行うのが困難なのではと思っていましたが、論文では様々なテクニックを組み合わせて効果的に解決していました。
この実装は、もしかすると今後、色々なライブラリ・言語処理系にポートされていくのかもしれません。