この記事について
今回は、Paizaの「文字列収集」という問題を、文字列アルゴリズムの中でも代表的な
- Trie木
- Rolling Hash
を用いて解きます。なお、この記事ではこれらの文字列アルゴリズムそのものに関する説明はせず、それらの説明は他の記事に譲ることとします。
問題概要
まず、今回扱うのは以下の問題です。
$N$ 個の文字列 $S_1,S_2,\dots,S_N$ があり、各文字列には値 $P_i$ が定まっています。以下の形式の $M$ 個の質問すべてについて答えてください。
文字列 $T$ が与えられる。 $T$ から始まるような文字列全てに対する $P_i$ の値を総和はいくら?
制約として
- $1\leq N \leq 10000$
- $1\leq M \leq 10000$
- $1\leq \vert S_i\vert \leq 100 (1 \leq i \leq N)$
- $1\leq P_i \leq 10000 (1 \leq i \leq N)$
- $1\leq \vert Q_i\vert\leq 100 (1 \leq i \leq M)$
が与えられています(なお、文字列の長さは $\vert\ast\vert$ のようにして表すことにします)。
解法1 (Trie木)
Trieのノードに
- このノードに対する $P_i$ の和
- 部分木に含まれるノードに対する $P_i$ の和
を持たせます。 追加のたび Bottom Upに更新するようにすればよいです。
#include <bits/stdc++.h>
using namespace std;
struct Trie {
struct node {
node* par; // 親へのポインタ
node* ch[26]; // 子へのポインタ
long long value = 0; // このnodeの P_i の和
long long sum = 0; // node の部分木全てに対する P_iの和
};
node* root = nullptr;
Trie() : root(new node()) {}
void add(const string& s, int p) {
// Sに対応するノードを検索する(なければ作る)
node* cur = root;
for (int i = 0; i < (int)s.size(); ++i) {
if (cur->ch[s[i] - 'a'] == nullptr) {
cur->ch[s[i] - 'a'] = new node();
cur->ch[s[i] - 'a']->par = cur;
}
cur = cur->ch[s[i] - 'a'];
}
// 上に上りつつ更新する
cur->value += p;
cur->sum += p;
while (cur != root) {
cur = cur->par;
cur->sum = cur->value;
for (int i = 0; i < 26; ++i) {
if (cur->ch[i]) cur->sum += cur->ch[i]->sum;
}
}
}
long long calc_sum(const string& s) {
// Sから始まる文字列に対応するノードを検索する
node* cur = root;
for (int i = 0; i < (int)s.size(); ++i) {
cur = cur->ch[s[i] - 'a'];
if (cur == nullptr) {
// Sから始まる文字列に対応するノードが無ければ0を返す
return 0;
}
}
// 見つかったらそこに書かれてるsumを返す
return cur->sum;
}
};
int main() {
int n, m;
cin >> n >> m;
Trie t;
for (int i = 0; i < n; ++i) {
string s;
int p;
cin >> s >> p;
t.add(s, p);
}
for (int i = 0; i < m; ++i) {
string s;
cin >> s;
cout << t.calc_sum(s) << endl;
}
}
$l=\max {\vert S_i\vert},\sigma=26$ として、計算量は $O(\sigma l)$です。
解法2(Rolling Hash)
実は、質問内容を以下のように改変しても答えは変わりません。
- $S_1$ の先頭 $1$ 文字が $T$ に一致すれば答えに $+P_1$
- $S_1$ の先頭 $2$ 文字が $T$ に一致すれば答えに $+P_1$
- $S_1$ の先頭 $3$ 文字が $T$ に一致すれば答えに $+P_1$
...- $S_1$ の先頭 $\vert S_1\vert$ 文字が $T$ に一致すれば答えに $+P_1$
- $S_2$ の先頭 $1$ 文字が $T$ に一致すれば答えに $+P_2$
- $S_2$ の先頭 $2$ 文字が $T$ に一致すれば答えに $+P_2$
...
...- $S_N$ の先頭 $\vert S_N$ 文字が $T$ に一致すれば答えに $+P_N$
結局答えはいくらか?
このように改変すると、たとえば、文字列をkey, 対応する文字列の個数をvalueとして持つ連想配列を用いることで質問に回答できます。しかし、連想配列のkeyが文字列なので、実は挿入に時間がかかってしまいます。そこでkeyとして文字列のHash値を持つことにより、検索を高速化できます。
計算量は $O(\sum \vert S_i\vert+\sum\vert T_i\vert)$ です。
#include <bits/stdc++.h>
using namespace std;
using u64 = uint64_t;
constexpr u64 mod = 998244353;
constexpr u64 base = 4294966043; // 適当な素数
int main() {
int n, m;
cin >> n >> m;
unordered_map<u64, long long> cnt;
for (int i = 0; i < n; ++i) {
int p;
string s;
cin >> s >> p;
u64 hash = 0;
for (int j = 0; j < (int)s.size(); ++j) {
hash = hash * base + u64(s[j]);
hash %= mod;
cnt[hash] += p;
}
}
for (int i = 0; i < m; ++i) {
string t;
cin >> t;
u64 hash = 0;
for (int j = 0; j < (int)t.size(); ++j) {
hash = hash * base + u64(t[j]);
hash %= mod;
}
cout << cnt[hash] << endl;
}
}
なお、ハッシュ値の衝突が心配であれば、ハッシュのbase
を複数個とるようにすると衝突確率を下げることができます。