2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

競プロテクニック集 - BinaryHeap

Posted at

ここではよくインポートしてそんなに意識せず使いがち?なHeapについて解説しておきます。

BinaryHeapとは

BinaryHeapは内部の要素数を $N$ とした時二分木構造によって新たな要素を追加したり最大ないし最小の値を取り出したりするのを一回あたり $O(\log{N})$ でできるようにしたものです。
内部構造としては各要素に高々二つまでの子を持つ二分木構造となっていて、根に求める要素が置かれた状態になっています。以下最小ヒープで挙動を説明しますが最大ヒープは大小を反転させるだけでできます。

具体的な操作での挙動

以下このような状態になっているものを例に取って説明します。各ノードの値が書かれてます。
image.png

挿入

BinaryHeapに新しく要素を追加する場合まず一番後ろに要素を追加します。その後その要素が根ではなくその親より小さい限り親と交換するというソートを繰り返します。例えば上の例で $2$ を追加するとすると $6$ と書かれた頂点の後ろに追加されます。このとき $2$ は $6$ より小さいので $2$ と $6$ の位置が入れ替わります。同様にしてもう一回上り $1$ の下に来たら $1$ 以上なので停止するといった具合です。そして二分木の構造をしているのでこのソート回数は高々 $O(\log{N})$ 回です。

取り出し

一番小さい要素を取り出す時は根の要素を取り出せばよいです。実装的には根の要素と一番下の葉の要素を交換してから根だった要素を取り出します。例えば最初の例では一番下についている $7$ と $1$ の頂点を交換し $1$ をその後削除します。ただこのままでは根に最小でない要素が残る場合もあるのでその葉の要素で自身より小さいものがある限り下げていきます。図においては $7$ は $2$ と $5$ のうち小さい要素である $2$ より大きいので $2$ と $7$ を交換します。同様にして $7$ は最初の $11$ の上のノードで停止します。このような交換を行うことによって最小のものが常に根にあるようにできます。これも $O(\log{N})$ で行えます。

このような操作によって根に常に最小値が来るようにできました。これは最小値が根以外にあるとするとその親はもっと小さいはずなので矛盾することから背理法でわかります。

挿入と取り出しを同時に行う

さて以下は応用的な操作になります。挿入と取り出しを1回ずつセットで行う時挿入するものが取り出されるものより小さいとき挿入を行う必要がありません。また挿入と取り出しを行うときも関数呼び出しのオーバーヘッドを減らすことができる上、挿入するのも根にそのまま追加要素をおけば良いのでメモリ確保もおきないことから定数倍がよくなります。Pythonのheapqにおいてはheappushpopとして実装されています。

特定要素の変更

すでにheap内に追加した要素に対して操作を加えたくなる場合があります。
(問題の例: https://atcoder.jp/contests/agc023/tasks/agc023_f)
追加した要素はバブルソートを繰り返されることになるのでどこにあるかは元のままではわかりません。なのでheapに追加した要素をキーとしてindexを保持する辞書を別途作成し交換を全て追跡します。こうすることで辞書からindexを特定し変更を行います。そして変更に応じて上下に移動させれば良いです。これも $O(\log{N})$ となりますが定数倍は重めです。
(PythonやRustは辞書、HashMapが重いです。制約的に配列で管理できる場合はそちらに変更した方が良いと思いますしそうでなくとも座標圧縮するのも視野に入ると思います。)

実装例

Pythonによる実装例は以下の通りです。
通常の最小ヒープ
(関数呼び出しが多めになっている都合上実行速度は遅めです。もし速度が欲しいならfloatとsinkをpush,popに直接書いてください。)

最小ヒープ
class BinaryHeap:
    def __init__(self):
        self.tree = []

    def push(self, x):
        p = len(self.tree)
        self.tree.append(x)
        self.float(p)

    def pop(self):
        p = len(self.tree)-1
        self.tree[0], self.tree[p] = self.tree[p], self.tree[0]
        res = self.tree.pop()
        self.sink(0)
        return res

    def pushpop(self, x):
        if self.tree[0] > x:
            return x
        else:
            res = self.tree[0]
            self.tree[0] = x
            self.sink(0)
            return res

    def float(self, p):
        x = self.tree[p]
        while p and self.tree[(p-1)//2] > x:
            par = (p-1)//2
            self.tree[par], self.tree[p] = x, self.tree[par]
            p = par

    def sink(self, p):
        while p*2+1 < len(self.tree):
            nex = 2*p+1
            if nex+1 < len(self.tree) and self.tree[nex] > self.tree[nex+1]:
                nex += 1
            if self.tree[p] > self.tree[nex]:
                self.tree[p], self.tree[nex] = self.tree[nex], self.tree[p]
                p = nex
            else: break

キー付きBinaryHeap
(定数倍がかなり悪いです。またPythonの辞書のハックを防ぐためランダムxorで対策してますが使用感は変わりません。)

キー付きヒープ
class KeyedBinaryHeap:
    # keyにlambda a, b: ~などの比較関数を乗せる
    def __init__(self, key=lambda a, b: a[0] < b[0]):
        self.tree = []
        self.dic = {}
        self.key = key
        from random import randint
        self.xor = randint(100, 1 << 60)

    def push(self, idx, x):
        if idx in self.dic:
            self.update(idx, x)
            return
        p = len(self.tree)
        self.tree.append((x, idx))
        self.dic[idx ^ self.xor] = p
        self._float(p)

    def pop(self):
        p = len(self.tree)-1
        self._swap(0, p)
        res = self.tree.pop()
        self._sink(0)
        del self.dic[res[1] ^ self.xor]
        return res

    def pushpop(self, idx, x):
        if idx ^ self.xor in self.dic:
            self.update(idx, x)
            return self.pop()
        if not self.tree and self.key((idx, x), self.tree[0]):
            return idx, x
        else:
            res = self.tree[0]
            self.tree[0] = (x, idx)
            self.dic[idx ^ self.xor] = 0
            del self.dic[res[1] ^ self.xor]
            self._sink(0)
            return res

    def get(self, idx):
        return self.tree[self.dic[idx ^ self.xor]][0]

    def update(self, idx, x):
        if idx ^ self.xor not in self.dic:
            return
        p = self.dic[idx ^ self.xor]
        self.tree[p] = (x, idx)
        self._float(p)
        p = self.dic[idx ^ self.xor]
        self._sink(p)

    def _float(self, p):
        x = self.tree[p]
        while p and self.key(x, self.tree[(p-1)//2]):
            par = (p-1)//2
            self._swap(p, par)
            p = par

    def _sink(self, p):
        while p*2+1 < len(self.tree):
            nex = 2*p+1
            if nex+1 < len(self.tree) and self.key(self.tree[nex+1], self.tree[nex]):
                nex += 1
            if self.key(self.tree[nex], self.tree[p]):
                self._swap(p, nex)
                p = nex
            else: break

    def _swap(self, x, y):
        _, k1 = self.tree[x]
        _, k2 = self.tree[y]
        self.dic[k1 ^ self.xor], self.dic[k2 ^ self.xor] = y, x
        self.tree[x], self.tree[y] = self.tree[y], self.tree[x]

    def __contains__(self, item):
        return item ^ self.xor in self.dic

使用場面

さてここまで説明してきましたが実際に使う場面はどのような場合でしょうか。それは順序によらない集合における最大ないし最小となる要素を個数分走査することなく取得したい場面になります。特に一回すべて走査してから一番いいものを取る貪欲法と相性がよいです。また最大K個、ダイクストラ法など典型アルゴリズムでも割と出てきます。以下にheapを使って解ける問題を挙げます。ネタバレを含みますのでときたい方は先に問題を解いておいてください。

ABC331 - E - Set Meal

AtCoder 食堂では主菜と副菜からなる定食が販売されています。
主菜は $N$ 種類あり、順に主菜 $1$, 主菜 $2$, …, 主菜 $N$ と呼びます。主菜 $i$ の価格は $a_i$ 円です。
副菜は $M$ 種類あり、順に副菜 $1$, 副菜 $2$, …, 副菜 $M$ と呼びます。副菜 $i$ の価格は $b_i$ 円です。
定食は主菜と副菜を $1$ 種類ずつ選んで構成されます。定食の価格は選んだ主菜の価格と副菜の価格の和です。ただし、$L$ 個の相異なる組 $(c_1, d_1), …, (c_L, d_L)$ について、主菜 $c_i$ と副菜 $d_i$ からなる定食は食べ合わせが悪いため提供されていません。つまり、提供されている定食は $NM−L$ 種類あることになります。(提供されている定食が少なくとも 1 種類存在することが制約によって保証されています。)
提供されている定食のうち、最も価格の高い定食の価格を求めてください。
$1 ≤ N, M ≤ 10^5$
$0 ≤ L ≤ \min(10^5, NM-1)$
$1 ≤ a_i, b_i ≤ 10^9$
$1 ≤ c_i ≤ N$
$1 ≤ d_i ≤ M$

この問題では存在しない組み合わせがある中で総和が最大のものを求める必要があります。制約を見ると $L ≤ 10^5$ となっているのでこれを利用します。つまり定食のうち価格が高いものを列挙していけば $L+1$ 個以内に答えを見つけることができるというわけです。あとはどう列挙するかです。主菜を一つ固定したとき副菜が高い順に定食価格も高くなるので高い順に副菜を走査し、ないならその次の副菜との定食を考えればいいです。よって主菜ごとに何番目の副菜まで見たかをindexで管理すればよいことがわかり価格の高いものを探すにheapを使えばこの問題を $O(L\log{N})$ で解くことができました。実装例では自作クラスを使用していますがもちろんheapqを使用してかまいません。

実装例
class BinaryHeap:
    def __init__(self):
        self.tree = []

    def push(self, x):
        p = len(self.tree)
        self.tree.append(x)
        self.float(p)

    def pop(self):
        p = len(self.tree)-1
        self.tree[0], self.tree[p] = self.tree[p], self.tree[0]
        res = self.tree.pop()
        self.sink(0)
        return res

    def pushpop(self, x):
        if self.tree[0] > x:
            return x
        else:
            res = self.tree[0]
            self.tree[0] = x
            self.sink(0)
            return res

    def float(self, p):
        x = self.tree[p]
        while p and self.tree[(p-1)//2] > x:
            par = (p-1)//2
            self.tree[par], self.tree[p] = x, self.tree[par]
            p = par

    def sink(self, p):
        while p*2+1 < len(self.tree):
            nex = 2*p+1
            if nex+1 < len(self.tree) and self.tree[nex] > self.tree[nex+1]:
                nex += 1
            if self.tree[p] > self.tree[nex]:
                self.tree[p], self.tree[nex] = self.tree[nex], self.tree[p]
                p = nex
            else: break


def main():
    n, m, l = map(int, input().split())
    a = list(map(int, input().split()))
    b = list(map(int, input().split()))
    b = [(b[i], i)for i in range(m)]
    b.sort(reverse=True)
    mx = 100010
    for _ in range(max(mx-m, mx-l)):
        b.append((-1 << 60, mx))
    out = set()
    for _ in range(l):
        u, v = map(lambda x: int(x)-1, input().split())
        out.add((u, v))
    index = [0]*n
    heap = BinaryHeap()
    for i in range(n):
        heap.push((-(a[i]+b[0][0]), i, b[0][1]))
    for _ in range(l+2):
        v, i, j = heap.pop()
        if (i, j) in out:
            index[i] += 1
            heap.push((-(a[i]+b[index[i]][0]), i, b[index[i]][1]))
        else:
            print(-v)
            return


if __name__ == "__main__":
    main()

ABC308 - F - Vouchers

あなたは店で $N$ 個の商品を買おうとしています。
$i$ 個目の商品の定価は $P_i$ 円です。
また、あなたは $M$ 枚のクーポンを持っています。
$i$ 枚目のクーポンを使うと、定価が $L_i$ 円以上の商品を一つ選び、その商品を定価より $D_i$ 円低い価格で買うことができます。
ここで、一つのクーポンは一回までしか使えません。また、複数のクーポンを同じ商品に重ねて使うことはできません。
クーポンを使わなかった商品は定価で買うことになります。
$N$ 個すべての商品を買うのに必要な最小の金額を求めてください。
制約
$1 ≤ N, M ≤ 2 × 10^5$
$1 ≤ P_i ≤ 10^9$
$1 ≤ D_i ≤ L_i ≤ 10^9$

特にうまく前処理やまとめて計算することもできなさそうな条件設定ですがとりあえずクーポンをできる限り使った方が、かつ使うなら使える中で一番割引額が大きいものを使いたいと考えるはずです。よって貪欲法で解こうという方針が立ちます。商品及びクーポンを金額の安い順に走査し、商品でクーポンが残っているならその中で一番割引額が大きいものを使うとすれば最適になります。これは使用可能なクーポンを可能な限り早く使用しても損しないことからわかります。よってHeapを用いることで $O(N\log{M})$ で解くことができました。

実装例
class BinaryHeap:
    def __init__(self):
        self.tree = []

    def push(self, x):
        p = len(self.tree)
        self.tree.append(x)
        self.float(p)

    def pop(self):
        # ないときにバグらないように
        if not self.tree: return 0
        p = len(self.tree)-1
        self.tree[0], self.tree[p] = self.tree[p], self.tree[0]
        res = self.tree.pop()
        self.sink(0)
        return res

    def pushpop(self, x):
        if self.tree[0] > x:
            return x
        else:
            res = self.tree[0]
            self.tree[0] = x
            self.sink(0)
            return res

    def float(self, p):
        x = self.tree[p]
        while p and self.tree[(p-1)//2] > x:
            par = (p-1)//2
            self.tree[par], self.tree[p] = x, self.tree[par]
            p = par

    def sink(self, p):
        while p*2+1 < len(self.tree):
            nex = 2*p+1
            if nex+1 < len(self.tree) and self.tree[nex] > self.tree[nex+1]:
                nex += 1
            if self.tree[p] > self.tree[nex]:
                self.tree[p], self.tree[nex] = self.tree[nex], self.tree[p]
                p = nex
            else: break


def main():
    n, m = map(int, input().split())
    p = list(map(int, input().split()))
    l = list(map(int, input().split()))
    d = list(map(int, input().split()))
    x = []
    for v in p:
        x.append((v, 1<<60))
    for v, sub in zip(l, d):
        x.append((v, -sub))
    x.sort()
    ans = 0
    heap = BinaryHeap()
    for i in range(n+m):
        if x[i][1] > 0:
            res = heap.pop()
            ans += x[i][0] + res
        else:
            heap.push(x[i][1])
    print(ans)


if __name__ == "__main__":
    main()

ABC306 - E - Best Performances

長さ $N$ の数列 $A = (A_1, A_2, …, A_N)$ があり、最初全ての項が $0$ です。入力で与えられる整数 $K$ を用いて $f(A)$ を以下のように定義します。
・ $A$ を降順ソートしたものを $B$ とする。
・ このとき $f(A) = B_1+B_2+…+B_K$ とする。
この数列に $Q$ 回更新を行うことを考えます。
数列 $A$ に対し以下の更新を $i=1, 2, …, Q$ について行い、各更新ごとにその時点での $f(A)$ の値を出力してください。
・ $A_{X_i}$ を $Y_i$ に変更する
$1 ≤ K ≤ N ≤ 5 × 10^5$
$1 ≤ Q ≤ 5 × 10^5$
$1 ≤ X_i ≤ N$
$1 ≤ Y_i ≤ 10^9$

要するに数列を更新しつつ最大 $K$ 個の和を出力せよということです。通常のBinaryHeapでは追跡して更新することができませんしこれはSoterdContainerとかsetかなぁと思われるでしょうしむしろそっちが正攻法だと思います。しかし!キー付きBinaryHeapなら変更や追跡もできます! $K$ 個以下に収まるものとそうでないものをそれぞれ最小ヒープ、最大ヒープに入れて変更したのち入れ替わるかをみつつ同時に答えを更新すれば答えを出すことができます。よってこの問題を $O(Q\log{N})$ で解くことができました。
(ちなみに実装例はかなり定数倍が悪いのですが制約上indexを配列で持つことができるので割と高速化できると思います。また非本質ですがinputをsys.stdin.readline()に置き換えることで定数倍高速化ができます。詳細は後日定数倍高速化で出しますが今は他の記事を参照してください。)

実装例

class KeyedBinaryHeap:
    # keyにlambda a, b: ~などの比較関数を乗せる
    def __init__(self, key=lambda a, b: a[0] < b[0]):
        self.tree = []
        self.dic = {}
        self.key = key
        from random import randint
        self.xor = randint(100, 1 << 60)

    def push(self, idx, x):
        if idx in self.dic:
            self.update(idx, x)
            return
        p = len(self.tree)
        self.tree.append((x, idx))
        self.dic[idx ^ self.xor] = p
        self._float(p)

    def pop(self):
        p = len(self.tree)-1
        self._swap(0, p)
        res = self.tree.pop()
        self._sink(0)
        del self.dic[res[1] ^ self.xor]
        return res

    def pushpop(self, idx, x):
        if idx ^ self.xor in self.dic:
            self.update(idx, x)
            return self.pop()
        if not self.tree and self.key((idx, x), self.tree[0]):
            return idx, x
        else:
            res = self.tree[0]
            self.tree[0] = (x, idx)
            self.dic[idx ^ self.xor] = 0
            del self.dic[res[1] ^ self.xor]
            self._sink(0)
            return res

    def get(self, idx):
        return self.tree[self.dic[idx ^ self.xor]][0]

    def update(self, idx, x):
        if idx ^ self.xor not in self.dic:
            return
        p = self.dic[idx ^ self.xor]
        self.tree[p] = (x, idx)
        self._float(p)
        p = self.dic[idx ^ self.xor]
        self._sink(p)

    def _float(self, p):
        x = self.tree[p]
        while p and self.key(x, self.tree[(p-1)//2]):
            par = (p-1)//2
            self._swap(p, par)
            p = par

    def _sink(self, p):
        while p*2+1 < len(self.tree):
            nex = 2*p+1
            if nex+1 < len(self.tree) and self.key(self.tree[nex+1], self.tree[nex]):
                nex += 1
            if self.key(self.tree[nex], self.tree[p]):
                self._swap(p, nex)
                p = nex
            else: break

    def _swap(self, x, y):
        _, k1 = self.tree[x]
        _, k2 = self.tree[y]
        self.dic[k1 ^ self.xor], self.dic[k2 ^ self.xor] = y, x
        self.tree[x], self.tree[y] = self.tree[y], self.tree[x]

    def __contains__(self, item):
        return item ^ self.xor in self.dic


def main():
    n, k, q = map(int, input().split())
    query = [list(map(int, input().split()))for _ in range(q)]
    heap1 = KeyedBinaryHeap(key=lambda x, y: x[0] < y[0])
    heap2 = KeyedBinaryHeap(key=lambda x, y: x[0] > y[0])
    heap2.push(-1 << 60, 0)
    for i in range(k):
        heap1.push(i, 0)
    for i in range(k, n):
        heap2.push(i, 0)
    ans = 0
    for idx, v in query:
        idx -= 1
        if idx in heap1:
            ans += v-heap1.get(idx)
            heap1.update(idx, v)
        else:
            heap2.update(idx, v)
        v1, k1 = heap1.pop()
        v2, k2 = heap2.pop()
        if v1 < v2:
            heap1.push(k2, v2)
            heap2.push(k1, v1)
            ans += v2-v1
        else:
            heap1.push(k1, v1)
            heap2.push(k2, v2)
        print(ans)


if __name__ == "__main__":
    main()

とりあえずは以上です。Heapは順序に関わらず最大最小を取ってこれるので頭の片隅に置いておくといいと思います。

参考文献

・秋葉拓哉, 岩田陽一 & 北川宜稔. 『プログラミングコンテストチャレンジブック : 問題解決のアルゴリズム活用力とコーディングテクニックを鍛える』. マイナビ出版, 2024.
・Python Software Foundation "heapq --- ヒープキューアルゴリズム", Python 3.13 ドキュメント, 2024, https://docs.python.org/ja/3/library/heapq.html

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?