ABC440C - Striped Horse
https://atcoder.jp/contests/abc440/tasks/abc440_c
ちょっと悩んだ問題ですが、最終的には解けました。考察の過程を書いてみます。
考察
問題文の意味がわからなかったので入力例のテストケース1つめの数字を入れながらなぞってみました。
N, W = 8, 2
C = 1, 10, 10, 1, 1, 10, 10, 1
$ 1 \leq i \leq 8 $ なる i のうち (i + x) ÷ 4 が W 未満になるものだけを黒く塗ります。x は自由に決めてよいです。例えば x = 4 の場合
| i | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
|---|---|---|---|---|---|---|---|---|
| i+x | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 |
| mod4 | 1 | 2 | 3 | 0 | 1 | 2 | 3 | 0 |
| 黒塗り | ○ | ○ | ○ | ○ |
書いてみるとだんだんわかってきますね。x を 1 つずつ大きくしていくと、4 回周期で同じ結果がやってきます。
| x | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
|---|---|---|---|---|---|---|---|---|
| 1 | 2 | 3 | 0 | 1 | 2 | 3 | 0 | 1 |
| 2 | 3 | 0 | 1 | 2 | 3 | 0 | 1 | 2 |
| 3 | 0 | 1 | 2 | 3 | 0 | 1 | 2 | 3 |
| 4 | 1 | 2 | 3 | 0 | 1 | 2 | 3 | 0 |
今の場合はこの表の mod 4 の値が 0 または 1 のときに黒く塗ります。x を 1 大きくするたびに塗られる位置が 1 つずつ動いていきますね。これを繰り返すと次の図のようになります。
こんな風に 2W 周期で1周してくるわけです。長さ N の配列を見て、連続で W 個塗ってまた W 個空けての繰り返しです。2W 通りのパターンをそれぞれ計算し、コストが最小になるものを取れば答えが出ます。
問題はこれをどうやって計算するかです。一つずつ計算してたら間に合わなさそうなのでなんとか差分更新で片付けたいところですが、差分更新するにも $ N \div W $ 箇所の更新をやる羽目になるので N, W の大きさ次第では TLE になりそうです。
ここで C の添字を 2W で割った余りごとに管理することを思いつきます。C[0] が塗られるとき、必ず C[2W] も塗られます。C[1] が塗られるとき、必ず C[2W+1] も塗られます。この性質を利用して、C の添字を i として、i を 2W で割った余りごとに C[i] をそれぞれ足し合わせておきます。入力例1の最初のテストケースならこうなります。
C = 1, 10, 10, 1, 1, 10, 10, 1
に対して添字 i を 4 で割った余りごとにまとめると
S = [2, 20, 20, 2]
となります。あとはこの長さ 2W の配列から連続する W 個を取り出して合計を計算していけばいいです。
実装
添字のズレに注意します。紙に書いて丁寧に確認しながらコードを書きました。
- 最初は 0~W-1 を黒く塗る。
- 次は 1~W を黒く塗る。つまり out S[0], in S[W]
- 次は 2~W+1 を黒く塗る。つまり out S[1], in S[W+1]
- ……
- 次は 0~W-3と2W-2~2W-1を黒く塗る。
- 最後は 0~W-2と2W-1を黒く塗る。つまり out S[2W-2], in S[3W-2] つまり S[W-2]
out S[i] したら、in S[(i+W) % 2W] ですね。
T = int(input())
ans = []
for _ in range(T):
N, W = map(int, input().split())
C = list(map(int, input().split()))
tmp_ans = 10**20
# 添字 i を見て、あまり 0 ~ 2W-1 までに分類して、それぞれの合計値を出す。
S = [0 for _ in range(2*W)]
for i, c in enumerate(C):
num = i % (2*W)
S[num] += c
# 長さ 2W の配列 S の中で連続 W 個の和を求める。その中で最小のものを見つける。
tmp_sum = 0
# 最初は 0 から W-1
for i in range(W):
tmp_sum += S[i]
tmp_ans = min(tmp_ans, tmp_sum)
for i in range(2*W - 1): # 2W-1回ずらしていく
out_S = i
in_S = (i + W) % (2*W)
tmp_sum -= S[out_S]
tmp_sum += S[in_S]
tmp_ans = min(tmp_ans, tmp_sum)
ans.append(tmp_ans)
for an in ans:
print(an)

