競プロ典型 90 問005 - Restricted Digitsの解法メモ
解法のポイント
1) 「数そのもの」じゃなくて「Bで割った余り」だけ見ればいい
N桁の数を作る過程で、次の桁 d を付けたときの更新が
new = (10 * r + d) * mod B
と 余りだけで閉じるので、「状態は 0〜B-1 のB通りで十分」と気づける。
2) Nが巨大だから「1桁ずつ」ではなく「連結して倍々に伸ばす」
普通の桁DPは N回 更新しますが、この問題は N ≤ 10^18 なので不可能。
そこで、
- 長さaのブロックと長さbのブロックを連結すると余りが
(i * 10^b + j) * mod Bになる
という性質を使って、「長さを 1→2→4→8… と二乗で作る」=繰り返し二乗法に持ち込む。
1) 何を数えているか(分布DP)
長さ L のとき
dp[r] =(使える数字で作った L 桁の数のうち、B で割った余りが r になる個数)
という 余りごとの個数分布を持ちます。
2) combine がやっていること
A を「左ブロック(長さ a)」、D を「右ブロック(長さ b)」とすると、
- 左の余りが
i - 右の余りが
j
のとき、連結後の余りは
r = (i * 10^b + j) % B
なので、個数は掛け算で
Cnew[r] += A[i] * D[j]
になります。ここで t = 10^b % B を引数でもらって r = (i*t + j) % B にしてます。
つまり combine(A, D, t) は「長さ a の分布」と「長さ b の分布」から「長さ a+b の分布」を作る関数です。
3) base / res の意味
-
base:いま持っているブロックの分布
最初は長さ1なので、base[d%B]++(1桁はその数字そのもの) -
res:答え側(これまでに確定して連結した部分)の分布
最初は「何も選んでない」長さ0が1通りなのでres[0]=1
4) pow10 の意味と更新
-
pow10は「いまの base の桁数 len に対して 10^len % B」 -
baseをbase ⊗ baseにすると、桁数が len → 2len になります
だから次に必要なのは10^(2len) % Bで、これは
10^(2len) = (10^len)^2より
pow10 = pow10 * pow10 % Bと更新できます。
5) while(N>0) の流れ(二進法でN桁を作る)
N を2進数で見て、
- ビットが1なら
res = res ⊗ base(その長さのブロックを答えに採用) - 毎回
base = base ⊗ base(ブロック長を倍にする) -
N >>= 1で次のビットへ
これで「必要な 1,2,4,8,… 桁ブロックだけを連結」して ちょうどN桁を作れます。
最後に res[0] が「Bで割り切れる(余り0)」の個数なので、それを出力して終了です。
Cnew[r] += A[i] * D[j]の補足
まず A[i], D[j] が何を表してるか
-
A[i]= 「左ブロック(長さa)で、余りが i になる作り方の数」 -
D[j]= 「右ブロック(長さb)で、余りが j になる作り方の数」
です。
連結すると「掛け算」になる理由
左側で「余り i になる具体的な数」は A[i] 個あります。
右側で「余り j になる具体的な数」は D[j] 個あります。
このとき、連結してできる数は
- 左の選び方(A[i]通り) × 右の選び方(D[j]通り)
なので、組み合わせ総数は A[i] * D[j] 通りになります(積の法則)。
そして、左の余りが i・右の余りが j の組は全部、同じ余り
r = (i * 10^b + j) % B
に行くので、その余りのカウント Cnew[r] に足します:
Cnew[r] += A[i] * D[j]
小さい例で確認してみる
仮に「左の余り2になる数が3個」(A[2]=3)、「右の余り5になる数が4個」(D[5]=4) なら、
余り(2,5)の組み合わせは
- 左3通り × 右4通り = 12通り
この12通りは全部、同じ r に落ちるので Cnew[r] に 12 足す、ということです。
解答
#include <bits/stdc++.h>
using namespace std;
long long MOD = 1e9 + 7;
vector<long long> combine(const vector<long long>& A,
const vector<long long>& D, long long t, int modB) {
vector<long long> Cnew(modB, 0);
for (int i = 0; i < modB; i++) {
for (int j = 0; j < modB; j++) {
long long r = (i * t + j) % modB;
Cnew[r] = (Cnew[r] + A[i] * D[j]) % MOD;
}
}
return Cnew;
}
int main() {
long long N, B, K;
cin >> N >> B >> K;
vector<long long> cnt(B, 0);
vector<long long> C(K, 0);
for (int i = 0; i < K; i++) {
cin >> C[i];
}
vector<long long> base(B, 0);
vector<long long> res(B, 0);
for (int d : C) {
base[d % B]++;
}
res[0] = 1;
long long pow10 = 10 % B;
while (N > 0) {
if (N & 1) {
res = combine(res, base, pow10, B);
}
base = combine(base, base, pow10, B);
pow10 = (pow10 * pow10) % B;
N >>= 1;
}
cout << res[0] << '\n';
return 0;
}