ABC426 C - Upgrade Required
https://atcoder.jp/contests/abc426/tasks/abc426_c
多くの人が解けているのにさっぱりわかりませんでした……ですが復習することで大事なことを学べました。前にも同じ問題に引っかかっているので、過去の学習を活かせなかったということでもありますが。
コンテスト中の動き
ABを3分で片付けC問題に取りかかりました。(たぶんあかんやろな)というコードを書いて案の定TLE、この時点でコンテスト開始から28分です。すぐには解けなさそうな空気を感じ取りつつ思い切ってD問題に行きました。こちらはなぜかすぐに解法が閃き11分で解けました。
Cに戻ります。悩んで悩んで別の解法を試してみましたがそちらもTLEでした。むしろTLEの数が増えてしまっており、どうしようもなかったです。結局30分ほど残してギブアップしました。
最初の提出
入力例を例にとって説明します。「バージョンX以下のマシンが何台あるか」を高速に取得するため、累積和的にそれを記録しておきます。
バージョン | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
---|---|---|---|---|---|---|---|---|
初期 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
(2, 6) | 0 | 0 | 1 | 2 | 3 | 6 | 7 | 8 |
(3, 5) | 0 | 0 | 0 | 1 | 3 | 6 | 7 | 8 |
(1, 7) | 0 | 0 | 0 | 1 | 3 | 6 | 7 | 8 |
(5, 7) | 0 | 0 | 0 | 0 | 0 | 3 | 7 | 8 |
(7, 8) | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 8 |
最初に (X, Y) = (2, 6) が来て、バージョン2以下が2個あるのでバージョン5までの全てから2を引きます。マイナスになる場合は0にします。
次に (X, Y) = (3, 5) が来て、バージョン3以下が1個あるのでバージョン4までの全てから1を引きます。
これを繰り返します。一応気づきとして、一度アップグレードを受けたバージョンのOSは全て0台になるので、この表の数字は必ず右肩上がりになります。また、後から小さなXが入力されてもそれ以上のXが既出であれば全部0台になっているので無視できます。これで一応少しだけの高速化はできますが……(配列の中で範囲内の減算をQ回もするからどうせTLEだろうな)と思って提出したところ、案の定TLEでした。他の解法が思いつかないのでいったん逃げます。
次の提出
順位表の正解者数を見たところ、このC問題を5000人以上の人が正答していました。ということはもっともっと簡単な何かで解けるはずです。一体何なのか?なぜ気づけないのか?プレッシャーを感じつつ考えました。延々と悩んでようやく出てきたのは「何のバージョンが何個あるのかを辞書型で持っておく」という発想です。30分ほど考察と実装に使い、提出しましたがやっぱりTLEでした。ここで心が折れます。
なんとなく考えていたのは「何かの回数がせいぜい N 回とか Q 回で済んで、一見 TLE になりそうな愚直っぽい解法が実は間に合う」というものでした。過去問でもそういうものを見たことがあります。今回はそこまで思いついていたのであとは計算量をちゃんと見積もれればよかったのですが……あくまで「なんとなく」であり冷静にはなりきれませんでした。
提出コードはこれです。貪欲にバージョンが小さなものから順に取り出していって処理します。ただし毎回バージョン1からバージョンXまで参照して足していったら計算量が大きくなるのでこれらを辞書型で管理し、0個になったバージョンは辞書型配列(下のコードでいうSD)から削除していきます。これで多少スキップできますね。
コードの中、特に内側の for ループがくどいですが、これはループ中に辞書型配列 SD の中身を削除するとおかしなことになってしまうので、削除したいもの一覧を持っておき後からあらためて削除するようにしています。
結局TLEです。
# 何が何台かをdictで管理
from sortedcontainers import SortedDict
N, Q = map(int, input().split())
SD = SortedDict([])
ans = []
for i in range(1, N+1):
SD[i] = 1
for i in range(Q):
X, Y = map(int, input().split())
tmp_ans = 0
to_deleted = set()
for k, v in SD.items():
if k <= X:
tmp_ans += v
to_deleted.add(k)
else:
break
if tmp_ans > 0:
SD[Y] += tmp_ans
ans.append(tmp_ans)
for d in to_deleted:
del SD[d]
for an in ans:
print(an)
ACできた提出
解説も読まないまま翌日になり、heapq とか平衡二分木という話を Twitterで見かけました。そこで、もしやと思い上記の SortedDict を heapq を用いたコードに書き換えました。するとあっさりACできました。
SortedListが重いという話は最近しているんですよね。それなのにそこに思い至ることができませんでした。これは情けないです。
https://qiita.com/omakasessan/items/c14a6da11928f52286a4
- 配列を見て、常にバージョンが最小のものから順に取り出していく
- 新たにどのバージョンのOSが何台増えるのかを配列に加えていく
これだけできればいいので heapq で十分なんですね。バージョンはN通りしかないので取り出す回数は最大N回です。加える回数はクエリごとに1回なので最大Q回です。となると計算量は最大 O((N+Q)logN) でいけそうですね。
import heapq
N, Q = map(int, input().split())
hq = [] # (バージョン番号、台数) の tuple を持たせる
heapq.heapify(hq)
ans = []
for i in range(1, N+1):
heapq.heappush(hq, (i, 1))
for _ in range(Q):
X, Y = map(int, input().split())
tmp_ans = 0
while (hq[0][0] <= X): # 目標とするバージョン以下のものが取り出せる限り続く
h = heapq.heappop(hq) # バージョンが最小のものを取り出して処理する
tmp_ans += h[1]
if tmp_ans > 0:
heapq.heappush(hq, (Y, tmp_ans))
ans.append(tmp_ans)
for an in ans:
print(an)
まとめ
SortedDict ではなく heapq を使うのがこの問題のポイントでした。また、そのことに気づくためには計算量の正しい見積もりが必要だったように思います。正確な計算量がわかっていれば、「それでも間に合わないなら定数倍遅くなったせいでTLEになってしまっているかもしれない」と予測できたかもしれません。計算量が「なんとなくこうしたらいけそう(無理かもしれない)」ぐらいの気持ちでしか予測できていなかったせいで、どこを疑えばいいのかわからなかったんですね。
また heapq の使い方もまだあまりわかっておらず、最初にACできたコード(この記事には貼っていません)はかなり冗長でした。レートや経験の割にあまりにお粗末な状況です。今後はなるべく heapq で済ませられないか考えながら解くように心がけます。