Crit-bit Treeを実装するにあたって
今回の内容はTrie木の実装の一つであるCrit-bit Treeだ.
Trieとは
まずはTrieとは何かについて簡単に説明しよう.
Trieとは連想配列の実装の一つで、
複数の異なる文字列に効率的に保存し,検索できるようにするものだ.
複数の文字列を効率的に登録,削除ができ,それぞれの文字列に任意のものを割りあてられるものだ.
pythonで書くとこんな感じ.
# 辞書でのアクセス
a={}
a["hello"]="world"
c++で書くとこんな感じ
std::map<std::string,std::string> a;
a["hello"]="world"
で,これを実装しようと思った時,
- 線形探索
- ハッシュ
- 木
の3つぐらいがすぐに思い付くであろう.
1は論外として2のハッシュであっても
ハッシュって衝突判定まわりをきちんと作ろうと思うと結構面倒.
そもそも、ハッシュテーブル分のメモリを確保しておくなんてメモリ効率が悪いと言えなくもない.
そこで木の構造を使えば効率的に検索できると考えるのが筋であろう.
木でTrieを実装してみる
どうやって文字列を木にするかを考えてみる.
ここで文字列とは1[byte]=8[bit]の列だとしよう.
そしてこのバイト列は当然可変長であり,それを効率的に保存できる構造である必要がある.
1[byte]が取り得る値は0-255の256通りであり,
文字列とは256x256x...という風に256の積の集合であるといえる.
まあ当然こんなものを全て網羅するような巨大な配列は用意できないし,スカスカすぎる.
ではこれを木の構造で考えた時どのようなものになるだろう.
配列で検索できるようにしてみる案
一文字,1[byte]をノードに割り当てることを考える.
256通りの可能性があるから,次のようなものであろう.
typedef struct{
node* c[256];
} node;
一つのノードで256の配列を用意しないといけない.
効率悪っ!
文字そのものを保存して一致検索してみる案
じゃあ,逆に文字列そのものを保存してノードを辿って検索することを考えてみよう.
こんな感じ
typedef struct{
uint8_t c;
node** next;
} node;
node**
なのは何故だか分かるだろうか.
node** next
は子ノードへのポインタの配列だ.
子ノードは256通りあるが,その全てを網羅すると結局配列検索案と一緒になってしまうため,
必要な子ノードの数だけ用意することにする.
これって結局線形探索やんか!
実際にその通りで,効率が悪い.
亜種としてこれを二分木に対応させる手法もあって,この場合は配列を持つ必要はない.
その場合はこんな感じ.
typedef struct{
uint8_t c;
node* sibling;
node* child;
} node;
これは文字列の先頭からの距離が等しい文字を兄弟とし,
その次の文字を子として保存する.
結局,子ノードを線形探索するのか,兄弟を線形探索するのかだけの違いで,ほぼ同じだ.
今まで紹介してきたアルゴリズムは簡単に思い付くだけあって効率が悪い.
先人がもっとすごいデータ構造でTrieを実現しているが,今回はその中の一つを紹介する.
その名もCrit-bit Treeだ.
その他にもダブル配列が有名だが,これを超簡単に説明すておこう.
例のnode* c[256]
があったと思う.
このノードのほとんどはスカスカで中身がない.
したがって,ちょっとズラして別のノードの配列を重ねられるんじゃないの?というのがアイディアだ.
実際にはどんな配列を持つのかとか,どう重ねるのかとか,もっと複雑なのだがイメージはずらして重ねるだ.
Crit-bit Tree
Crit-bit TreeはD. J. Bernesteinさんが開発したアルゴリズムでおどろくほど単純だ.
Crit-bit trees
ノードの構造はこんな感じ
typedef struct{
size_t length;
node* right;
node* left;
} node;
まあ,これだけ見ると何が何だか分からないだろう.
Crit-bit Treeのアイディアはこうだ.
文字列を検索できるとは,この二つが区別できるということだ.
いやいや,それ以外の入力が来るかもしれないでしょ
そうじゃない,
Trieである文字列が木にすでに登録されているか,そもそもないのかを判定したいとする.
ここで登録されていな文字列が入った時に間違ったノードに案内されたとしても,
- それが葉でなければ登録されていないと分かる.
- 葉であっても,葉に登録された文字列を保存しておけば一致比較するのは$O(N)$でよい.
つまり,検索に失敗しても葉ノードで一致比較すればそれは分かる.
チートと思われるかもしれないが
アルゴリズムはチートしてなんぼのものであり,頭の柔軟性が物を言う.
また,どんなアルゴリズムであっても巨大メモリを確保しない限り,$O(N)$になるのは防ぎようがない.
では次に進む.
まず二つの文字列,単純のためビット列を考えよう.
1110101010111110
1111101111111111
二つの文字列を区別するということだが,どこに着目すればよいだろうか.
二つの文字列の違う所といえば,XOR的な?
若干当たっているが,もっと効率的にできる.
二つの文字列を前から比較した時,最初に異なるbitである.
これがsize_t length
である.
Bernsteinさんは.たったこれだけで二つの文字列の違いを定義してしまった.
この二つの文字列で違う部分のことをCrit-bitと言いこのアルゴリズムの名前の由来になっている.
後はそれの応用だ.
つまり,全ての文字列を対象として考えると,前から順番にそのビットを比較していく.
そして異なるビットが見つかった時,前からのbitの数をlength
に保存していく.
すると,ルートに近いほど文字列の前の部分だけに着目し,ルートから離れるほど後半部分に着目することになる.
検索する時はこんな感じ.
- まずルートノードの
length
のビットを見る. - そして,それが$0$であればright,$1$であればleftに行く(実際にはどっちでも良いが統一しておく)
- そうやって葉のノードまでいく
- $0$,$1$の判定後の子ノードがなければ,登録されていないため失敗
- 葉に到達した場合は,葉に登録されている文字列と比較
こんなアルゴリズムだ.
注意点
今まで述べてきたことは,半分これからの説明のための前振りである.
これを実装する時に気をつけないといけないのが,
長さの違う文字列をどう扱うのか?という問題だ.
いやいや大丈夫でしょ
と思ったあなた.
次のパターンを考えて下さい
111011
1110110000000
1110111111111
1110110001010
上のアルゴリズムでこれらをどうやって区別しますか?
短い文字列にパディングすると,パディング後の文字列に一致してしまうんですよ!
ここに嵌まって大変苦労しました.
だから文字列の長さ自体の情報を持たせる必要があります.
つまりこういう感じのことです.
typedef struct{
uint8_t* str;
size_t size;
} Str;
uint8_t get_c(Str* s, size_t i){
if ( s->size >= i ){
return 0;
} else {
return s->str[i];
}
}
長さの情報は保持しつつも,それ以上の長さを検索に使用する場合には$0$を返すようにしておかないと,
検索に失敗します.
ちなみに,文字列から文字を取り出す時にこんなことしていると効率が悪いのでCrit-bitを計算する時に対応させるのも手です.
size_t First1(uint8_t c) {
int lb = 4;
/* binary search */
/* 1111 0000 */
/* <- ^-> */
if (c & B_1111_0000) {
lb -= 2;
c = c & B_1111_0000;
} else {
lb += 2;
}
/* 1100 1100 */
/* <^> <^> */
if (c & B_1100_1100) {
lb -= 1;
c = c & B_1100_1100;
} else {
lb += 1;
}
/* 1010 1010 */
/* <^<^ <^<^ */
if (c & B_1010_1010) {
lb -= 1;
} else {
// lb+=1;
}
return lb;
}
int CritBit(uint8_t *t, size_t tlen, uint8_t *s, size_t slen, size_t *pos) {
/* 1. CRITBIT_SUCCESS
* target |----+---|
* source |----+-----|
*
* target |-----------+---|
* source |--------|000000|
*
* target |--------|000000|
* source |-----------+---|
*
* 2. CRITBIT_SAME
* crit bit and pos is undefined.
* target |--------|
* source |--------|
*/
size_t max = slen < tlen ? tlen : slen;
for (size_t i = 0; i < max; i++) {
uint8_t T = i < tlen ? *t : 0;
uint8_t S = i < slen ? *s : 0;
if (T != S) {
size_t f = First1(T ^ S);
*pos = i * 8 + f;
return ((T << f) & B_1000_0000) == 0 ? 0 : 1;
}
t++;
s++;
}
return -1;
}