LoginSignup
2
2

【Python】競プロer用サンプルコード集

Last updated at Posted at 2021-07-27

二分探索

任意の関数内の目的の値を$O(\log(n))$で探索するアルゴリズム

def binary_search(func, x, lo=0, hi=10**18):
    while lo < hi:
        mid = (lo+hi)//2
        if func(mid) < x: lo = mid+1
        else: hi = mid
    return lo
  • 戻り値: func(i)x以上となる最小のi
  • funcは広義単調増加関数である必要がある(途中で減少しない)

三分探索

関数の最小値を$O(\log(n))$で探索するアルゴリズム

def ternary_search(func, lo=0, hi=10**18, precision=10**-18, is_continuous=False):
    if is_continuous:
        while precision < 1:
            mid1, mid2 = (hi+2*lo)/3, (2*hi+lo)/3
            if func(mid1) <= func(mid2):
                hi = mid2
            else:
                lo = mid1
            precision *= 3/2
    else:
        while lo < hi:
            mid1, mid2 = (lo*2+hi)//3, (lo+hi*2)//3
            if func(mid1) <= func(mid2):
                hi = mid2
            else:
                lo = mid1+1
    return func(lo)
  • 戻り値は関数の最小値
  • 関数が最小値を取る$x$が欲しい場合はreturnする値をloに変更する
  • lo: 探索開始時の最小値、hi: 探索開始時の最大値、precision: 許容誤差
  • is_continuous: 連続→True、離散→False

基数変換

def convert_to_base(n, base):
    digits = []
    while n:
        digits.append(n % base)
        n //= base
    return digits

def convert_to_decimal(digits, base):
    return sum(x*base**i for i, x in enumerate(digits))
  • convert_to_baseは、nbase進数で表したときの各桁をリストで返す
  • convert_to_decimalは、各桁がdigitsで表されたbase進法の数を10進法に直す

ダイクストラ法

各頂点までの距離を$O((v+e)\log(v))$で求めるアルゴリズム

from heapq import heappop, heappush

def dijkstra(edge_list, s):
    n = len(edge_list)
    todo = []
    heappush(todo, (0, s))
    dist_list = [float('inf')]*n
    dist_list[s] = 0

    while todo:
        dist, node = heappop(todo)
        for to, cost in edge_list[node]:
            if dist_list[to] > cost + dist:
                dist_list[to] = cost + dist
                heappush(todo, (cost + dist, to))
    return dist_list
  • edge_listは各頂点ごとに(行き先頂点、距離)の集合を格納したリスト
  • 第二引数は始点
  • 戻り値は各頂点までの距離のリスト

UnionFind木

グループ分けを効率的に処理するための木構造

class UnionFind:
    def __init__(self, init_list, func=lambda x,y: x+y):
        self.n = len(init_list)
        self.parents = [-1] * self.n
        self.values = init_list[:]
        self.func = func
 
    def root(self, x):
        if not 0 <= x < self.n: raise Exception("uf index out of range")
        li = []
        while x >= 0:
            li.append(x)
            x = self.parents[x]
        y = li.pop()
        for x in li:
            self.parents[x] = y
        return y
 
    def find(self, x, y):
        return self.root(x) == self.root(y)
 
    def unite(self, x, y):
        r1 = self.root(x)
        r2 = self.root(y)
        if r1 == r2: return
        p1 = self.parents[r1]
        p2 = self.parents[r2]
        
        value = self.func(self.values[r1], self.values[r2])
        if p1 <= p2:
            self.parents[r2] = r1
            self.values[r1] = value
            if p1 == p2:
                self.parents[r1] -= 1
        else:
            self.parents[r1] = r2
            self.values[r2] = value
    
    def get_val(self, x):
        y = self.root(x)
        return self.values[y]
    
    def __len__(self):
        return self.n
    
    def __getitem__(self, x):
        return self.get_val(x)
    
    def print_unionfind(self):
        print(*[self.root(i) for i in range(self.n)])
  • self.parentsの要素は非負なら親のインデックス、負ならランクを格納
  • funcunite時にグループが持つ値同士の演算を指定
  • init_listで個々が持つ初期値を指定
  • デフォルトではget_valxと同じグループに属する個数が得られる

SortedSet

要素の追加、要素の削除、$x$以上の最小の要素の検索を$O(\sqrt{n})$で行えるデータ構造

こちらの記事を参照

  • 初期化はs = SortedSet([要素リスト])で行う
  • s.add(x)で要素の追加($O(\sqrt{n})$)
  • s.discard(x)で要素の削除($O(\sqrt{n})$)
  • s.lt(x)でxより小さい最大の要素、s.gt(x)でxより大きい最小の要素を返す
  • s.le(x)でx以下で最大の要素、s.ge(x)でx以上で最小の要素を返す

逆元の計算

逆元の計算を$O(\log MOD)$で行う。

def find_inv(a, MOD):
    c, u, v = MOD, 1, 0
    while c:
        t = a // c
        a, c = c, a-t*c
        u, v = v, u-t*v
    u %= MOD
    return u
  • Pythonには標準でpow関数が存在するが、PyPy3だとmodで負の冪乗を計算できなかったので自作
  • 例えばpow(5, -2, MOD)は、pow(find_inv(5, MOD), 2, MOD)により代用できる

modつき階乗

素数のmodが指定されている場合の階乗の計算を行う。$n!$までの計算をしたい時、前計算$O(n)$、毎計算$O(1)$で求めることができる。

class Factorial:
    def __init__(self, n, MOD):
        self.fc_list = [1, 1]
        self.fc_inv_list = [1, 1]
        self.inv_list = [1, 1]
        for i in range(2, n+1):
            self.inv_list.append(self.inv_list[MOD % i] * (MOD - MOD//i) % MOD)
            self.fc_inv_list.append(self.fc_inv_list[i-1]*self.inv_list[i] % MOD)
            self.fc_list.append(self.fc_list[i-1] * i % MOD)
    
    def fact(self, n):
        if n >= 0: return self.fc_list[n]
        else: return self.fc_inv_list[-n]
    
    def comb(self, n, m):
        if n < 0 or m < 0: raise Exception("n and m must not be negative")
        if n < m: return 0
        return self.fc_list[n]*self.fc_inv_list[m]*self.fc_inv_list[n-m] % MOD
    
    def perm(self, n, m):
        if n < 0 or m < 0: raise Exception("n and m must not be negative")
        if n < m: return 0
        return self.fc_list[n]*self.fc_inv_list[n-m] % MOD
  • fact(n)は$n!$、comb(n, m)は${}_n C_m$、perm(n, m)は${}_n P_m$にそれぞれ対応する
  • 初期化時のnより大きな整数を引数とした場合や、combpermで負の引数を指定した場合の動作は保証しない

行列の積

行列AとBの積を計算する。MODを指定すると各要素がi % MODで計算される。

def mat_mul(A, B, MOD=None):
    ax, ay = len(A[0]), len(A)
    bx, by = len(B[0]), len(B)
    if ax != by: raise Exception("dimension mismatch")
    
    C = [[0]*bx for _ in range(ay)]
    for y in range(ay):
        for x in range(bx):
            for i in range(ax):
                C[y][x] += A[y][i]*B[i][x]
                if MOD: C[y][x] %= MOD
    return C
  • PyPyでnumpyが使えないため自作
  • A = matmul(A, A)で行列累乗ができる

セグメント木

区間に対する操作を$O(\log n)$で行えるデータ構造

class SegTree:
    def __init__(self, init_list, func=lambda x,y: x+y, ide_ele=0):
        self.n = len(init_list)
        self.length = 1<<(self.n-1).bit_length()
        self.node_list = [ide_ele]*(2*self.length)
        self.func = func
        self.ide_ele = ide_ele
        for i in range(self.n):
            self.node_list[i+self.length] = init_list[i]
        for i in range(self.length-1, 0, -1):
            self.node_list[i] = self.func(self.node_list[2*i], self.node_list[2*i+1])
    
    def add(self, index, x):
        if not 0 <= index < self.n: raise Exception("segtree index out of range")
        index += self.length
        self.node_list[index] = self.func(self.node_list[index], x)
        while index > 1:
            self.node_list[index>>1] = self.func(self.node_list[index], self.node_list[index^1])
            index >>= 1
    
    def update(self, index, x):
        if not 0 <= index < self.n: raise Exception("segtree index out of range")
        index += self.length
        self.node_list[index] = x
        while index > 1:
            self.node_list[index>>1] = self.func(self.node_list[index], self.node_list[index^1])
            index >>= 1
    
    def query(self, l, r):
        if not (0 <= l <= self.n and 0 <= r <= self.n): raise Exception("segtree index out of range")
        ans = self.ide_ele
        l += self.length
        r += self.length
        while l < r:
            if l & 1:
                ans = self.func(ans, self.node_list[l])
                l += 1
            if r & 1:
                ans = self.func(ans, self.node_list[r-1])
            l >>= 1
            r >>= 1
        return ans
    
    def __len__(self):
        return self.n
    
    def __getitem__(self, index):
        if type(index) != int: raise Exception("segtree indices must be integers")
        if not 0 <= index < self.n: raise Exception("segtree index out of range")
        return self.node_list[index+self.length]
    
    def __setitem__(self, index, value):
        if type(index) != int: raise Exception("segtree indices must be integers")
        if not 0 <= index < self.n: raise Exception("segtree index out of range")
        self.update(index, value)
    
    def print_segtree(self):
        print(*self.node_list[self.length:self.length+self.n])

最大流(Dinic法)

重み付き有向グラフにおいて各辺の容量を超えずに点$s$から点$t$まで送ることのできる量の最大値
計算量は$O(|V|^2 |E|)$だが実用的にはかなり速い

from collections import deque

class Dinic:
    def __init__(self, n):
        self.n = n
        self.graph = [dict() for _ in range(n)]
        self.level_list = [-1]*n
    
    def add_edge(self, fr, to, cap, is_directed=True):
        if to in self.graph[fr]:
            self.graph[fr][to] += cap
            if not is_directed:
                self.graph[to][fr] += cap
        else:
            self.graph[fr][to] = cap
            if is_directed:
                self.graph[to][fr] = 0
            else:
                self.graph[to][fr] = cap
    
    def flow(self, fr, to, rate):
        rate = min(rate, self.graph[fr][to])
        self.graph[fr][to] -= rate
        self.graph[to][fr] += rate
        return rate
    
    def bfs(self, start, end):
        graph = self.graph
        self.level_list = [-1]*self.n
        self.level_list[start] = 0
        todo = deque([start])
        while todo:
            t = todo.popleft()
            lv = self.level_list[t] + 1
            for node, cap in graph[t].items():
                if cap > 0 and self.level_list[node] == -1:
                    self.level_list[node] = lv
                    if node == end: return True
                    todo.append(node)
        return False
    
    def dfs(self, start, end):
        level_list = self.level_list
        todo = [start]
        path = [(-1, -1, float('inf'))]
        while todo:
            t = todo[-1]
            u, v, d = path[-1]
            if v == t:
                todo.pop()
                path.pop()
            else:
                if v != -1: d = min(d, self.graph[v][t])
                path.append((v, t, d))
                if t == end:
                    for u, v, _ in path:
                        if u == -1: continue
                        self.flow(u, v, d)
                    return d
                for node, cap in self.graph[t].items():
                    if cap > 0 and level_list[node] > level_list[t]:
                        todo.append(node)
        return 0
    
    def max_flow(self, start, end):
        flow = 0
        while self.bfs(start, end):
            f = float('inf')
            while f:
                f = self.dfs(start, end)
                flow += f
        return flow

ローリングハッシュ

文字列の一致判定等を文字列の長さに比例しない計算量で行えるアルゴリズム

class RollingHash():
    def __init__(self, item_list, base_list, MOD=10**9+7):
        self.n = len(item_list)
        self.base = len(base_list)
        self.MOD = MOD
        self.base_dict = {base: i for i, base in enumerate(base_list)}
        self.hash_list = [0]
        
        rate = 1
        for i, item in enumerate(item_list):
            index = self.base_dict[item]
            x = (self.hash_list[-1]+rate*index) % MOD
            rate = rate*self.base % MOD
            self.hash_list.append(x)
    
    def get_hash(self, l, r):
        if l > r: raise Exception("r must be higher than l")
        if not (0 <= l and r <= self.n): raise Exception("hash index out of range")
        return (self.hash_list[r] - self.hash_list[l]) * pow(self.base, -l, self.MOD) % self.MOD
    
    def __len__(self):
        return self.n
  • item_listにはハッシュ化したい配列、base_listにはitem_listの構成要素を指定する
  • get_hashlrを指定しitem_list[l:r]のハッシュ値を得られる。

最小共通祖先

木の2つの頂点の最小共通祖先の探索を$O(\log n)$で行えるアルゴリズム


class LCA:
    def __init__(self, edge_list, root):
        self.root = root
        self.n = len(edge_list)
        self.m = (self.n-1).bit_length()
        self.edge_list = edge_list
        self.depth_list = [float('inf')]*self.n
        self.parents_list = [[-1]*self.m for _ in range(self.n)]
        
        todo = [root]
        self.depth_list[root] = 0
        while todo:
            t = todo.pop()
            for node in self.edge_list[t]:
                dist = self.depth_list[t]+1
                if self.depth_list[node] > dist:
                    todo.append(node)
                    self.depth_list[node] = dist
                    self.parents_list[node][0] = t
        for i in range(1, self.m):
            for j in range(self.n):
                node = self.parents_list[j][i-1]
                if node == -1: continue
                self.parents_list[j][i] = self.parents_list[node][i-1]
    
    def find_parents(self, node, depth):
        for i in range(depth.bit_length()):
            if depth & (1<<i): node = self.parents_list[node][i]
        return node
        
    def find_lca(self, u, v):
        delta = self.depth_list[v] - self.depth_list[u]
        if delta < 0: u, v, delta = v, u, -delta
        v = self.find_parents(v, delta)
        if u == v: return u
        for i in range(self.depth_list[u].bit_length()-1, -1, -1):
            pu, pv = self.parents_list[u][i], self.parents_list[v][i]
            if pu != pv: u, v = pu, pv
        return self.parents_list[u][0]

  • edge_listは頂点ごとに行き先の頂点番号の集合を格納したリスト

補足

公式リファレンスや蟻本に掲載されている実装例や、AtCoderのコードなどを参考にさせて頂きました。

参考

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2