先週の ABC403 で Trie 木の問題が出ましたね。皆さんは解けましたか? 私は解けませんでした。
せっかくなので Trie 木を自前で実装していきましょう。実は ABC353-E の ユーザー解説(Trineutron さん)に恐ろしいほどスマートな実装があります。短いので全部抜粋します。
n = int(input())
s = input().split()
tree = {}
ans = 0
for t in s:
current = tree
for c in t:
if c in current:
ans += current[c][0]
current[c][0] += 1
else:
current[c] = [1, {}]
current = current[c][1]
print(ans)
あの Trie 木がたったこれだけで!?
これの本質を抜き出すと、
- なんらかの情報(この例だとカウント)
- 次の文字への遷移(連想配列)
を持つクラスを定義すれば再帰的に Trie 木を構築できるということになります。
これを C++ でもやってみます。
struct Node {
int cnt = 0;
map<char, Node*> to;
};
これだけで定義できます。注意点として、連想配列部分は Node
ではなく Node*
(ポインタ)で持っています。
上の解法を C++ で書き直すと以下のような感じになります。
#include <iostream>
#include <map>
#include <string>
using namespace std;
struct Node {
int cnt = 0;
map<char, Node*> to;
};
int main() {
int n;
cin >> n;
Node root;
long long ans = 0;
for (int i = 0; i < n; i++) {
string s;
cin >> s;
Node* now = &root;
for (char c : s) {
if (now->to.count(c)) {
ans += now->to[c]->cnt;
} else {
now->to[c] = new Node;
}
now->to[c]->cnt++;
now = now->to[c];
}
}
cout << ans << endl;
return 0;
}
デバッグ(オイラーツアー)
木の構造を見るためにオイラーツアー的なことをやります。サンプルは アルゴロジック様 からお借りしました。
#include <iostream>
#include <string>
#include <map>
using namespace std;
struct Node {
int cnt = 0; // このノードに到達した文字列の本数
map<char, Node*> to; // 次の文字へのポインタ
};
int main() {
int n;
cin >> n;
Node root;
// 入力される n 個の文字列を Trie に挿入
for (int i = 0; i < n; ++i) {
string s;
cin >> s;
Node* now = &root;
now->cnt++; // 空文字列もカウント
for (char c : s) {
if (!now->to.count(c)) {
now->to[c] = new Node();
}
now = now->to[c];
now->cnt++;
}
}
string path;
// DFSでTrieを探索
auto dfs = [&](auto& dfs, Node* now) -> void {
cout << now->cnt << " " << path << endl;
for (auto& [key, next] : now->to) {
path.push_back(key);
dfs(dfs, next);
path.pop_back();
}
};
dfs(dfs, &root);
}
5
fire
firearm
firework
fireman
algo
5
1 a
1 al
1 alg
1 algo
4 f
4 fi
4 fir
4 fire
1 firea
1 firear
1 firearm
1 firem
1 firema
1 fireman
1 firew
1 firewo
1 firewor
1 firework
応用その1(最長共通接頭辞)
この木上でカウントが $1$ であることは、先頭~現在までが一致するような単語がただ $1$ つしかないことを意味します。逆に、カウントが $2$ 以上であることは、そのような単語が $2$ つ以上存在すること意味し、これは共通接頭辞を持つことを意味します。
このことから、カウントが $2$ 以上であるようにギリギリ伸ばしていけば、文字列集合全体での他の文字列との最長共通接頭辞の長さを調べることができます。
#include <bits/stdc++.h>
using namespace std;
// Trie のノード構造体
struct Node {
int cnt = 0; // このノードまで到達した文字列の個数
map<char, Node*> to; // 次の文字への枝(子ノードへのポインタ)
};
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n;
cin >> n;
// 入力される n 個の文字列を格納
vector<string> ss(n);
for(int i = 0; i < n; i++){
cin >> ss[i];
}
// Trie の根ノードを動的に確保
Node* root = new Node();
// 1) Trie の構築フェーズ
for(const string& s : ss){
Node* now = root;
now->cnt++; // 空文字もカウント(各文字列の先頭に到達したので +1)
for(char c : s){
// 枝がなければ新しくノードを作成
if(now->to.find(c) == now->to.end()){
now->to[c] = new Node();
}
// 子ノードへ移動してカウントをインクリメント
now = now->to[c];
now->cnt++;
}
}
// 2) 各文字列について「他の文字列と共有している最大の接頭辞長」を求めるフェーズ
for(const string& s : ss){
Node* now = root;
int idx = 0;
int m = s.size();
// 文字を一文字ずつたどり、到達ノードの cnt が 2 以上なら idx++ して深さを伸ばす
while(idx < m){
char c = s[idx];
Node* next = now->to[c];
if(next->cnt <= 1){
// この枝を通る文字列が自分しかいない → 共有はここまで
break;
}
now = next;
idx++;
}
// idx が他の文字列と共有している最大の長さ
cout << idx << "\n";
}
return 0;
}
応用その2(ノードに特殊な意味を持たせる)
禁止単語
以下のような処理をしたいです。
- 「禁止文字列」と「普通の文字列」の $2$ 種類が存在する
- 木上で「禁止文字列」より後の文字列は、追加しても意味がない
- 「禁止文字列」の追加に成功した場合→そのカウントをルートまで引き続ける
- 「普通の単語」の追加に成功した場合→ルートまでカウントを $+1$ し続ける
いままでの構造体だと情報不足なので、新たに「禁止ノードかそうでないか」の情報を設けます。
また、実装の都合上、親ノードの情報も追加します。
struct Node {
int cnt = 0;
bool banned = false;
Node* prev;
map<char, Node*> to;
};
コード全体としては以下のようになります。
#include <iostream>
#include <map>
#include <string>
#include <vector>
using namespace std;
// Trie のノード構造体
struct Node {
int cnt = 0; // このノードまで到達した文字列の個数
bool banned = false; // このノード以下を禁止するフラグ
Node* prev; // 親ノードへのポインタ(先頭は nullptr)
map<char, Node*> to; // 子ノードへの枝
};
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int q;
cin >> q;
// 根ノードを作成し、prev を nullptr に初期化
Node root;
root.prev = nullptr;
while (q--) {
int T;
string S;
cin >> T >> S;
// 今回たどるノードは根からスタート
Node* now = &root;
bool hit = false;
// 文字列 S の各文字を Trie 上でたどる
for (char c : S) {
// 枝がなければ新しくノードを作成し、親ポインタを設定
if (!now->to.count(c)) {
now->to[c] = new Node();
now->to[c]->prev = now;
}
now = now->to[c];
// banned フラグが立っていたら以降の処理は打ち切る
if (now->banned) {
hit = true;
break;
}
}
if (!hit) {
if (T == 1) {
// T==1:現在のノード以下を禁止する操作
int removed = now->cnt; // これまでのカウント分を取り除く
now->banned = true; // 以降、この枝は hit=true で無視される
// 親ノードに向かって cnt を引いていく
while (now != nullptr) {
now->cnt -= removed;
now = now->prev;
}
}
else if (T == 2) {
// T==2:現在のノード以下をカウントアップする操作
while (now != nullptr) {
now->cnt++;
now = now->prev;
}
}
}
// 最後に根ノードの cnt(全体のカウント)を出力
cout << root.cnt << "\n";
}
return 0;
}
自分より下で一番近いノード
- 末尾を追加したり削除したりして最も近い単語に一致させるためのコストは?
- (言い換え)Trie 木上で最も近いノードとの距離は?
- (言い換え)ある共通接頭辞から、最も近い下流ノードへの距離との合計の最小値は?
#include <iostream>
#include <string>
#include <map>
#include <algorithm>
using namespace std;
// Trie ノードの定義
struct Node {
int cnt = 0; // このノードまで到達した文字列の個数(未使用でも残す)
int nearest_depth = 1000000000; // このノードを末尾とする既存文字列の最短長さ
Node* prev = nullptr; // 親ノードへのポインタ
map<char, Node*> to; // 子ノードへの枝
};
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
// lambda: a = min(a, b) を簡潔に書くためのヘルパー
auto chmin = [](auto &a, auto b) {
if (b < a) a = b;
};
int n;
cin >> n;
// Trie の根ノードを作成し、空文字列長 0 を登録
Node root;
root.nearest_depth = 0;
// 入力された文字列を一つずつ処理
for (int i = 0; i < n; i++) {
string s;
cin >> s;
int m = s.size();
Node* now = &root;
int idx = 0;
// ans に最小コストを蓄える
// 初期値は十分大きく
int ans = 1000000000;
// 文字列 s を一文字ずつ Trie 上でたどりながらコストを計算
while (idx < m) {
char c = s[idx];
// 枝がなければ新規ノードを作成し、親へのポインタを設定
if (!now->to.count(c)) {
now->to[c] = new Node();
now->to[c]->prev = now;
}
// 「ここまで idx 文字をマッチさせて残りは挿入する」コストを計算
// m:文字列全長
// now->nearest_depth:このノード以降にすでにある文字列の長さ(最短)
// 2*idx:繰り返し位置を考慮した調整項
int cand = m + now->nearest_depth - 2 * idx;
chmin(ans, cand);
// 次の文字へ
now = now->to[c];
idx++;
}
// 文字列を全部たどり終えたあともう一度評価
{
int cand = m + now->nearest_depth - 2 * idx;
chmin(ans, cand);
}
// 今回の文字列について最小コストを出力
cout << ans << "\n";
// 後処理:たどってきたノードすべてに対して
// nearest_depth を今回の文字列長 m で更新(最短距離の登録)
while (now != nullptr) {
chmin(now->nearest_depth, m);
now = now->prev;
}
}
return 0;
}
やってみよう、Trie 木!
今回、一番伝えたかったことは、
struct Node {
int cnt = 0;
map<char, Node*> to;
};
これで(なんなら後半の一行だけでも)トライ木を実装できるということです。もしデバッグや改造をしたい人がいたら、記事中のコードを参考にしてみてください。