添え字ガチャはお得意ですか?
私は苦手です。得意な方はこの記事のありがたみがあまりないかもしれません 1。
この記事は?
添え字ガチャを回避するためのテクを紹介します。
実装・考え方の工夫というぐらいなので、あえてテクというほどでもないかもしれません。
問題設定
- $n$ に対して線形な処理は簡単に書ける
- 大部分は工夫すると連続する部分を一気に処理できる
- 切り替わりの部分や端の処理がめんどくさい、ミスをしやすい
「シミュレーション」の問題(つまり愚直にシミュレーションをすると間に合わないので高速化が求められる問題)も多いです。
テク
この記事のメインの部分です。時短で読みたい人はここだけ読んで理解すればおけです。
- まず愚直解法を書く( $O(n)$ でも気にしない)
- 安全地帯(端とか切れ目と無縁な場所)を安全な方法で高速化する
- 怪しい部分は周辺も含めて愚直のままにする
簡単ですね。 2. の部分がポイントです。添え字をきれいに合わせる必要はなくて、安全に高速化できる部分を雑に高速化するイメージです。少しでも怪しいかもと思ったら愚直にすれば良いです。全体を通して愚直部分があまり多くなければ TLE せず通すことができます。
例(ネタバレ注意)
例1
例として この問題 を考えてみましょう。
とりあえず愚直を書きます。
def calc(l, r):
i = 1
while 1:
if l >= r:
if i <= l:
l -= i
else:
break
else:
if i <= r:
r -= i
else:
break
i += 1
return (i-1, l, r)
T = int(input())
for _ in range(T):
l, r = map(int, input().split())
print("Case #" + str(_ + 1) + ":", *calc(l, r))
簡単ですね、問題の設定をそのまま書いただけです。ただこれだと small は通るけど large は落ちます。
これを高速化することを考えましょう。まあよく考えれば平方根演算とかを使って $O(1)$ で計算できそうな雰囲気もありますが、添え字バグを起こしまくりそうな予感がぷんぷんします。
def calc(l, r):
# 一気に進める関数
def subcalc(d, i):
a, b = 0, 10 ** 9
while b - a > 1:
k = (a + b) // 2
if (2 * i + k - 1) * k // 2 <= d:
a = k
else:
b = k
return (a, (2 * i + a - 1) * a // 2)
i = 1
while 1:
##### ##### 高速化ここから ##### #####
if l - r > i * 10: # 差が小さいとコーナーがありそうなので雑に余裕を持って
k, m = subcalc(l - r, i)
# k ステップ進めます(l がその分減ります)
i += k
l -= m
elif r - l > i * 10: # 差が小さいとコーナーがありそうなので雑に余裕を持って
k, m = subcalc(r - l, i)
# k ステップ進めます(r がその分減ります)
i += k
r -= m
if i >= 5 and l > r > l - i > 0: # ここも適当に進めます
k = l - r
while not (l - (i + k - 1) * k > 0 and r - (i + k) * k > 0):
k //= 2
if k == 0: break
else:
l -= (i + k - 1) * k
r -= (i + k) * k
i += k * 2
##### ##### 高速化ここまで ##### #####
if l >= r:
if i <= l:
l -= i
else:
break
else:
if i <= r:
r -= i
else:
break
i += 1
return (i-1, l, r)
T = int(input())
for _ in range(T):
l, r = map(int, input().split())
print("Case #" + str(_ + 1) + ":", *calc(l, r))
コード中のコメントからも分かるとおり、「高速化部分」は一気に $k$ ステップ進めるみたいな処理をやっていますが、この部分をまるまるコードから消しても( TLE を気にしなければ)問題なく動くコードになっています。
"subcalc" という関数で、一気に何ステップ進めて良いか計算していますが、ここはちょっと少な目になっても何も問題ないです。もし $1$ ~ $2$ ステップ足りない場合は愚直がその分増えるだけですし、最悪半分ぐらいしか進んでなかったとしてもまた "subcalc" が呼ばれて一気に進むのが繰り返されるので大丈夫です。
つまり「安全地帯」を抜け出してしまわないことだけを注意すれば、一気に進むステップ数はだいぶ雑に設定しても問題ありません。これによって実質的に添え字ガチャをせずに解くことができます 6。
例2
次は この問題 を見てみましょう。
これ初めてのチーム戦で緊張していたんですが、解法はこんな感じです 7。
N, K = map(int, input().split())
X, Y = map(int, input().split())
C = [0] * 101010
t = 0
for a in map(int, input().split()):
C[a] += 1
t += a
ans = 10 ** 100
cx = 0
for m in range(10 ** 5 + 5, K, -1):
while C[m]:
cy = max(m if C[m] else m - 1, (t + K - 1) // K)
ans = min(ans, cx * X + cy * Y)
cx += 1
C[m] -= 1
C[m-K] += 1
t -= K
cy = max(m if C[m] else m - 1, (t + K - 1) // K)
ans = min(ans, cx * X + cy * Y)
for m in range(K, -1, -1):
cy = max(m if C[m] else m - 1, (t + K - 1) // K)
ans = min(ans, cx * X + cy * Y)
while C[m]:
cx += 1
C[m] -= 1
t -= m
cy = max(m if C[m] else m - 1, (t + K - 1) // K)
ans = min(ans, cx * X + cy * Y)
print(ans)
高速化の際は、全体のコードは変えずに、余裕を持って高速化できる範囲で一気にステップを進めます。
N, K = map(int, input().split())
X, Y = map(int, input().split())
C = [0] * 101010
t = 0
for a in map(int, input().split()):
C[a] += 1
t += a
ans = 10 ** 100
cx = 0
for m in range(10 ** 5 + 5, K, -1):
while C[m]:
##### ##### 雑な高速化 ここから ##### #####
if m - 1 > (t + K - 1) // K and C[m] >= 3: # このへんはだいぶ雑に(余裕を持って)
cy = max(m if C[m] else m - 1, (t + K - 1) // K)
ans = min(ans, cx * X + cy * Y)
d = C[m] - 2
cx += d
C[m] -= d
C[m-K] += d
t -= K * d
dd = (t + K - 1) // K - (m - 1)
if dd >= 5 and C[m] >= 3: # このへんはだいぶ雑に(余裕を持って)
cy = max(m if C[m] else m - 1, (t + K - 1) // K)
ans = min(ans, cx * X + cy * Y)
d = min(dd - 3, C[m] - 2)
cx += d
C[m] -= d
C[m-K] += d
t -= K * d
##### ##### 雑な高速化 ここまで ##### #####
cy = max(m if C[m] else m - 1, (t + K - 1) // K)
ans = min(ans, cx * X + cy * Y)
cx += 1
C[m] -= 1
C[m-K] += 1
t -= K
cy = max(m if C[m] else m - 1, (t + K - 1) // K)
ans = min(ans, cx * X + cy * Y)
for m in range(K, -1, -1):
cy = max(m if C[m] else m - 1, (t + K - 1) // K)
ans = min(ans, cx * X + cy * Y)
while C[m]:
cx += 1
C[m] -= 1
t -= m
cy = max(m if C[m] else m - 1, (t + K - 1) // K)
ans = min(ans, cx * X + cy * Y)
print(ans)
AC コード 8 でも「雑な高速化」部分を挿入する感じで書いていますね。つまりこの部分がなくても( TLE を考えなければ)問題なく通るコードになっています。
添え字は少し余裕を持って設定しています。多少ずれても「安全地帯」から外れさえしなければ問題なく動くようになっているので、細かい添え字ガチャをする必要がありません。
例3
最後に 最近あったこの問題 も例として挙げてみます。上ふたつの例ではシミュレーションを一気に進めるという感じでしたが、この問題は少し趣向が違って $N$ 個のうち条件を満たすのはいくつかを求める問題です。
結論を言ってしまうと、レベルを固定すれば、「安全地帯」の範囲内では条件を満たす $n$ の個数が等差数列になります。この問題では「安全地帯」から外れるのは、全体の端っこと区間の左右が入れ替わるときです。
具体的には、レベル $k$ を固定すると
- $ (k-a_1)b_1 \le i \lt (k-a_1+1)b_1 \quad\cdots\quad ① \quad$
- $ (k-a_2)b_2 \le i \lt (k-a_2+1)b_2 \quad\cdots\quad ② \quad$
を両方満たす範囲が条件を満たします。 $k$ を動かしたときこの $2$ つの区間の位置関係は何度か変わるかもしれません。その変わる瞬間(例えば ① の右辺と ② の左辺が近くなるなどの場合)は「危険」ですが、それ以外は「安全」です。安全地帯の範囲では条件を満たすやつの個数は等差数列で動くので、まとめて計算することができます。コードはこんな感じです。
def calc(n, a1, b1, a2, b2):
s = max(a1, a2)
t = min(a1 + n // b1, a2 + n // b2) + 1
if s >= t: return 0
# 危険そうな位置を全部リストアップ
S = [s, t]
if b1 != b2:
x = (a1 * b1 - a2 * b2) // (b1 - b2)
if s <= x <= t:
S.append(x)
x = ((a1 - 1) * b1 - (a2 - 1) * b2) // (b1 - b2)
if s <= x <= t:
S.append(x)
x = (a2 * b2 - a1 * b1 + b1) // (b2 - b1)
if s <= x <= t:
S.append(x)
x = (a1 * b1 - a2 * b2 + b2) // (b1 - b2)
if s <= x <= t:
S.append(x)
T = {s, t}
for a in S:
for d in range(-2, 3): # 危険地帯の周辺も全部突っ込む
b = a + d
if s <= b <= t:
T.add(b)
X = sorted(T) # 危険地帯を表す配列。その間は等差数列になるはず
re = 0
for x, y in zip(X, X[1:]):
k = x
l = max((k - a1) * b1, (k - a2) * b2, 1)
r = min((k - a1 + 1) * b1, (k - a2 + 1) * b2, n + 1)
z1 = max(r - l, 0) # 区間の左端のこたえ
k = y - 1
l = max((k - a1) * b1, (k - a2) * b2, 1)
r = min((k - a1 + 1) * b1, (k - a2 + 1) * b2, n + 1)
z2 = max(r - l, 0) # 区間の右端のこたえ
re += (z1 + z2) * (y - x) // 2 # 等差数列の性質を利用して一気に合計(危険地帯ではひとつずつ合計している)
return re
T = int(input())
for _ in range(T):
n, a1, b1, a2, b2 = map(int, input().split())
print(calc(n, a1, b1, a2, b2))
AC コード では、危険そうな場所(とその周辺)を適当に突っ込んで、危険地帯では $1$ つずつ、それ以外では等差数列の公式を使って一気に処理するようにしています。
この解法では危険地帯として何を突っ込むかがポイントになりそうですが、突っ込む位置の添え字が $1$ とか $2$ とかずれていても(どのみちその周りを全部突っ込むので)問題ありませんし、実は危険じゃない箇所を危険地帯として入れてしまっても問題ありません 9。いずれにせよ添え字ガチャからは逃げることができました。
例4
(2021/7/29 追記)
某勉強会の宿題になっていた こちらの問題 でも使えました。
まずは愚直に書いてみます。 $N$ が小さいので計算量オーダーを落とす努力はしていません。
def calc():
a = B[0]
j = 0
s = 0
while j < 37 and B[j] == a:
s += B[j] - A[j]
j += 1
return s * 36 / j - (sum(B) - sum(A))
T = int(input())
for case in range(T):
X, N = map(int, input().split())
A = [0] * (37 - N) + sorted([int(a) for a in input().split()])
B = A[:]
ans = 0
c = 0
while c < X:
j = 0
while j < 37 and B[j] == B[0]:
j += 1
for jj in range(j)[::-1]:
if c >= X: continue
B[jj] += 1
c += 1
ans = max(ans, calc())
print("Case #" + str(case + 1) + ":", "{:.9f}".format(ans))
上ので small は通りますが、 hard は 1 兆円ぐらいあるのでこれだと落ちちゃいます。
ちょっと考察すればどういう場合に最適化されるかなどが分かりそうですが、ここでは 考察もしない 方向で「安全地帯高速化」のみで通したいと思います。
やることは簡単で、状況が変わらない範囲ではたくさん進めても問題ないので、雑に進めるコードを入れてみます。この例でも、元のコードはいじらず、高速化コードを挿入することで large でも通るコードになりました。
def calc():
a = B[0]
j = 0
s = 0
while j < 37 and B[j] == a:
s += B[j] - A[j]
j += 1
return s * 36 / j - (sum(B) - sum(A))
T = int(input())
for case in range(T):
X, N = map(int, input().split())
A = [0] * (37 - N) + sorted([int(a) for a in input().split()])
B = A[:]
ans = 0
c = 0
while c < X:
j = 0
while j < 37 and B[j] == B[0]:
j += 1
for jj in range(j)[::-1]:
if c >= X: continue
B[jj] += 1
c += 1
ans = max(ans, calc())
##### 高速化 ここから #####
if j >= 37: break
m = max(min(B[j] - B[j-1], (X - c) // j) - 5, 0)
if m:
for jj in range(j):
B[jj] += m
c += m
##### 高速化 ここまで #####
print("Case #" + str(case + 1) + ":", "{:.9f}".format(ans))
例5
(2022/11/27 追記)
本記事のテクとは少し違いますが、候補を広めに追加するという意味で ABCのこちらの問題 も紹介します。
from math import sqrt
A, B = map(int, input().split())
x = int((A / (2 * B)) ** (2 / 3)) # 相加相乗平均(微分でも)で候補を絞る
L = [0] + [x + i for i in range(-10, 11)]
# ↑ ここで候補とその周辺を広めに突っ込む
# 一回も操作をしない場合のゼロも忘れずに
ans = 10 ** 100
for a in L:
if a >= 0: # 広めに突っ込んだので、間違って負の数を使わないように注意
t = a * B + A / sqrt(1 + a)
ans = min(ans, t)
print(ans)
まとめ
「安全地帯」と「危険地帯」に分けてみたとき、高速化できるのは「安全地帯」の部分です。安全な範囲では一気に処理を進めることで高速化できます。
ここで安全地帯と危険地帯の境目を丁寧に分離するコードを書こうとする必要はありません。「危険地帯」は余裕を持って広めに取ることで、添え字ガチャから逃げることができます。つまり、範囲を多少間違えても通るようなコードにすることができます。
ところで
このテクを使うと、どうしても「余裕を持った」設定になることが多いので、実行時間とかコード長とかには無駄が残りがちです。なので完璧主義の方にはあまり好まれない可能性がありますが、コンテスト中にバグらせずに通すということを第一に考えると、本記事の考え方が使えることが多いかもしれません。
おわり
おわり
-
苦手な方に必ずありがたみがあるとは言ってません ↩
-
わりとピンと来ないぐらい範囲が広いですね。それぐらい一般的に使える可能性があるということでご理解ください。
特に次のような場合に使えます 2。 $n$ はとても大きい(線形な処理では間に合わない)とします。 ↩ -
説明の都合でコメントを入れる程度の改変はしています
ここまでだと何言ってるか分からないという人も多いかもしれないのでいくつか例を挙げます。信憑性(?)を持たせるために、コードは私が実際にコンテストで使ったものをほぼそのまま紹介します 3。 ↩ -
$k$ は適当なステップ数
そこで、「安全地帯」(この問題では $L$ と $R$ の大小が入れ替わらない、かつパンケーキがなくならない)の範囲では、一気に $k$ ステップ 4 進めることを考えます。 $k$ ステップ進めるのは平方根もいらないので比較的簡単にできますね。 ↩ -
おかげで私も通すことができました ↩
-
この記事を書いているときは解法は覚えていないので、画像を貼ることで説明したことにします。言いたいことは高速化部分は添え字ガチャしなくてもできるよという話なので、解法自体はこの記事ではあまり重要ではないです。
↩
-
ループ回数がほんの少し多くなるかもしれないですが ↩