はじめに
こんにちは、 ayataka です。
F - Manhattan Cafe の解説です。
Sample Codeは以下の構文を含みます。
#include <bits/stdc++.h>
using namespace std
苦手な方はそっとブラウザバックしてください。
dpの設計
$1 ≤ N ≤ 100$, $0 ≤ D ≤ 1000$ と、比較的小さいため、dp をしてみたくなります。具体的には以下のような dp を実装します。
dp[i][j][k] := i次元目まで考えた時に、pからjの距離,qからkの距離にある格子点の個数
さあ、遷移を考えましょう。まずは、全探索してみましょう。
Sample Code
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const ll mod = 998244353;
int main() {
int N, D;
cin >> N >> D;
vector<int> p(N), q(N);
for(int i = 0; i < N; i++) cin >> p[i];
for(int i = 0; i < N; i++) cin >> q[i];
vector<vector<vector<ll>>> dp(N+1, vector<vector<ll>>(D+1, vector<ll>(D+1)));
dp[0][0][0] = 1LL;
for(int i = 0; i < N; i++) {
for(int j = 0; j <= D; j++) {
for(int k = 0; k <= D; k++) {
ll val = 0LL;
/* i 次元目での r の位置を全探索 */
for(int r = -2000; r <= 2000; r++) {
int distp = abs(p[i] - r);
int distq = abs(q[i] - r);
if(distp <= j && distq <= k) {
val += dp[i][j-distp][k-distq];
val %= mod;
}
}
dp[i+1][j][k] = val;
}
}
}
ll ans = 0LL;
for(int j = 0; j <= D; j++) {
for(int k = 0; k <= D; k++) {
ans += dp[N][j][k];
ans %= mod;
}
}
cout << ans << endl;
return 0;
}
このコードの計算量は $O(ND^3)$ です。これではTLEしてしまうので、高速化したいです。
dp-tableの観察
dp-tableをじっと睨んで効率できないか考えましょう。以下、入力例 1が与えられたときの処理を考えます。
$dp[1][5][5]$の更新の時にもらってくるマスを考えます。先ほどのコードに次のように追加すると出力は以下のようになります。
-2 3 0
-1 4 1
0 5 2
1 4 3
2 3 4
3 2 5
4 1 4
5 0 3
0 | 1 | 2 | 3 | 4 | 5 | |
---|---|---|---|---|---|---|
0 | . | . | . | # | . | . |
1 | . | . | . | . | # | . |
2 | . | . | . | . | . | # |
3 | # | . | . | . | # | . |
4 | . | # | . | # | . | . |
5 | . | . | # | . | . | ! |
Debug Code
ll val = 0LL;
/* i 次元目での r の位置を全探索 */
for(int r = -2000; r <= 2000; r++) {
int distp = abs(p[i] - r);
int distq = abs(q[i] - r);
if(distp <= j && distq <= k) {
val += dp[i][j-distp][k-distq];
if(j == 5 && k == 5) cerr << r << ' ' << j-distp << ' ' << k-distq << endl;
val %= mod;
}
}
dp[i+1][j][k] = val;
上の表を見ていただければわかると思いますが、遷移には斜めの規則性があります。なぜ、斜め方向に遷移するのでしょうか?$i$次元目でのを遷移を考えます。
x | x | p[i] | x | x | q[i] | x | x | |
---|---|---|---|---|---|---|---|---|
distp | 2 | 1 | 0 | 1 | 2 | 3 | 4 | 5 |
distq | 5 | 4 | 3 | 2 | 1 | 0 | 1 | 2 |
[- | -] | [- | - | - | -] | [- | -] |
左から、$X$, $Y$, $Z$部分とすると、$X$, $Z$部分は$distp$と$distq$がともに増加していますが$Y$部分は、$distp$と$distq$が増加、減少をとっています。先ほどの図を用いると、以下のようになります。
0 | 1 | 2 | 3 | 4 | 5 | |
---|---|---|---|---|---|---|
0 | . | . | . | Z | . | . |
1 | . | . | . | . | Z | . |
2 | . | . | . | . | . | Y |
3 | X | . | . | . | Y | . |
4 | . | X | . | Y | . | . |
5 | . | . | Y | . | . | ! |
よって、$w = abs(p[i]-q[i])$とおくと、$dp[i+1][j][k]$は、$X$部分「$(j-1, k-w-1)$から端まで」と$Y$部分「$(j, k-w)$から$(j-w, k)$まで」と$Z$部分「$(j-w-1, k-1)$から端まで」の和であることが分かります。よってdp[i]
の斜め方向の累積和をとればよいです。遷移が$O(1)$で計算することができるようになり、全体の計算量$O(ND^2)$でこの問題を解くことができます。
提出URL
[$25$行目コメント部分、間違っていました。実装例では修正してあります。]
実装例
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using ld = long double;
const ll mod = 998244353;
const ll INF = (1LL << 60);
int main() {
int N, D;
cin >> N >> D;
vector<int> p(N), q(N);
for(int i = 0; i < N; i++) cin >> p[i];
for(int i = 0; i < N; i++) cin >> q[i];
vector<vector<ll>> dp(D+1, vector<ll>(D+1));
dp[0][0] = 1LL;
for(int i = 0; i < N; i++) {
/* s1にi番目のdpを入れておいて、s1からi+1番目のdpに遷移させる「貰う」dp */
vector<vector<ll>> s1(D+1, vector<ll>(D+1));
swap(s1, dp);
vector<vector<ll>> s2 = s1;
/* 累積和 */
/* X, Z部分 */
for(int j = 0; j < D; j++) {
for(int k = 0; k < D; k++) {
s1[j+1][k+1] += s1[j][k];
s1[j+1][k+1] %= mod;
}
}
/* Y部分 */
for(int j = 0; j < D; j++) {
for(int k = 0; k < D; k++) {
s2[j+1][k] += s2[j][k+1];
s2[j+1][k] %= mod;
}
}
int w = abs(p[i] - q[i]);
auto sum1 = [&](int j, int k) {
/* 外にある時は 0 */
if(j < 0 || k < 0) return 0LL;
return s1[j][k];
};
auto sum2 = [&](int j, int k, int w) {
int nj = j -w-1, nk = k +w+1;
/* 外にある時は修正 */
if(k < 0) {
j += k;
k = 0;
}
if(j < 0) return 0LL;
ll res = s2[j][k];
if(0 <= nj && nk <= D) {
/* いらない部分は引いておく */
res -= s2[nj][nk];
}
return res;
};
for(int j = 0; j <= D; j++) {
for(int k = 0; k <= D; k++) {
ll val = 0;
val += sum1(j- 1, k- w-1);
val += sum1(j- w-1, k- 1);
val += sum2(j, k-w, w);
val %= mod;
dp[j][k] = val;
}
}
}
ll ans = 0LL;
for(int i = 0; i <= D; i++) {
for(int j = 0; j <= D; j++) {
ans += dp[i][j];
ans %= mod;
}
}
/* after-contest で WA が出てしまうので追加 */
if(ans < 0) ans += mod;
cout << ans << endl;
return 0;
}
終わりに
お疲れさまでした。実装がちょっと大変ですね。初めてQiitaに投稿し、また数時間で書き上げた記事なので見づらい部分もあるかもしれませんが、暖かい目で見ていただければありがたいです。なお、類題にEDPC-Mがあります。よかったらどうぞ!