LoginSignup
0
0

Python 自分用のライブラリが欲しい

Last updated at Posted at 2024-06-12

Library

Atcoder用pythonライブラリ関数版
まだ未完成
code / code× : コードなし
code△ : コードはあるけどソースなし
code〇 : コードとソースあり

最短経路

ダイクストラ

code△
import heapq
def dijkstra(S,N,edge):
    """Single-source shortest paths with a binary heap (lazy deletion).

    S: start vertex, N: number of vertices (0 .. N-1),
    edge[u]: iterable of (v, w) pairs for edges u -> v with cost w
    (costs are assumed non-negative, as Dijkstra requires).
    Returns a list whose i-th entry is the minimum cost from S to i
    (float('inf') when i is unreachable).
    """
    dist = [float('inf') for _ in range(N)]
    dist[S] = 0
    heap = [(0, S)]
    while heap:
        d, node = heapq.heappop(heap)
        # stale heap entry: a shorter path to node was already settled
        if d > dist[node]:
            continue
        for nxt, weight in edge[node]:
            cand = weight + dist[node]
            if not cand < dist[nxt]:
                continue
            dist[nxt] = cand
            heapq.heappush(heap, (cand, nxt))
    return dist

ベルマンフォード

単一始点最短経路 (ベルマンフォード法)

code
code

トポロジカルソート

code
code

SCC

強連結成分分解
groupはトポロジカルソートされている
ABC357

code〇
import sys
sys.setrecursionlimit(10**8)
class Scc:
    """Strongly connected components of a directed graph (Tarjan's algorithm).

    Vertices are 0 .. n-1. Register edges with add_edge(fr, to), then call scc().
    The DFS inside scc() is recursive, so very deep graphs depend on the raised
    recursion limit (sys.setrecursionlimit) set before this class.
    """
    def __init__(self,n):
        self.n = n
        # list of (fr, to) pairs
        self.edges = []

    def add_edge(self,fr,to):
        """Add the directed edge fr -> to (both endpoints must be in 0 .. n-1)."""
        assert 0 <= fr < self.n
        assert 0 <= to < self.n
        self.edges.append((fr, to))

    def scc(self):
        """Return the strongly connected components as a list of vertex lists.

        groups[k] holds the vertices of component k in increasing order.
        Components are numbered in topological order of the condensed graph:
        an edge u -> v between different components always goes from a
        lower-numbered group to a higher-numbered one.
        As a side effect self.ids[v] is the component index of vertex v.
        """
        # Build a CSR adjacency: the targets of v are csr_elist[csr_start[v]:csr_start[v+1]].
        csr_start = [0] * (self.n + 1)
        csr_elist = [0] * len(self.edges)
        for fr,to in self.edges:
            csr_start[fr + 1] += 1
        for i in range(1,self.n+1):
            csr_start[i] += csr_start[i-1]
        counter = csr_start[:]
        for fr,to in self.edges:
            csr_elist[counter[fr]] = to
            counter[fr] += 1

        self.now_ord = self.group_num = 0
        # Tarjan stack: visited vertices not yet assigned to a component
        self.visited = []
        # low-link value of each vertex
        self.low = [0] * self.n
        # DFS preorder index; -1 = unvisited, n = already assigned to a component
        self.ord = [-1] * self.n
        # component index of each vertex (reverse topological while the DFS runs)
        self.ids = [0] * self.n
        def _dfs(v):
            self.low[v] = self.ord[v] = self.now_ord
            self.now_ord += 1
            self.visited.append(v)
            for i in range(csr_start[v], csr_start[v+1]):
                to = csr_elist[i]
                if self.ord[to] == -1:
                    _dfs(to)
                    self.low[v] = min(self.low[v], self.low[to])
                else:
                    # For vertices already in a finished component ord[to] == n,
                    # so this min() leaves low[v] unchanged.
                    self.low[v] = min(self.low[v], self.ord[to])
            if self.low[v] == self.ord[v]:
                # v is the root of a component: pop the stack down to v.
                while 1:
                    u = self.visited.pop()
                    self.ord[u] = self.n
                    self.ids[u] = self.group_num
                    if u==v: break
                self.group_num += 1
        for i in range(self.n):
            if self.ord[i] == -1: _dfs(i)
        # Tarjan finishes sink components first; flip the numbering to get topological order.
        for i in range(self.n):
            self.ids[i] = self.group_num - 1 - self.ids[i]

        groups = [[] for _ in range(self.group_num)]
        for i in range(self.n):
            groups[self.ids[i]].append(i)
        return groups

グリッドDFS

code
code

サイクル検出

dfsをして閉路の頂点集合を返す
ChatGPT作
ABC266F

code〇
import sys
sys.setrecursionlimit(10 ** 8)

def detect_cycle(graph, N):
    """Find one cycle in an undirected graph by DFS and return its vertices.

    graph: adjacency list; graph[v] iterates the neighbours of v.
    N: number of vertices (0 .. N-1).

    Returns the vertices of the first cycle found, in cycle order: it starts at
    the vertex where the back edge was discovered, followed by the ancestor it
    points to and then down the DFS tree. Returns None if the graph is a forest.

    The edge back to the DFS parent is ignored, so two parallel edges between
    the same pair of vertices are not reported as a cycle.

    The DFS uses an explicit stack instead of recursion, so long paths do not
    overflow the interpreter's C stack.
    """
    visited = set()
    parent = {}
    for root in range(N):
        if root in visited:
            continue
        visited.add(root)
        parent[root] = None
        # Each entry is (vertex, iterator over its remaining neighbours), so a
        # vertex resumes from where it stopped after a child finishes.
        stack = [(root, iter(graph[root]))]
        while stack:
            v, neighbours = stack[-1]
            for nb in neighbours:
                if nb not in visited:
                    visited.add(nb)
                    parent[nb] = v
                    stack.append((nb, iter(graph[nb])))
                    break
                if nb != parent[v]:
                    # Back edge v - nb: nb is an ancestor of v in the DFS tree.
                    path = [v]
                    cur = v
                    while cur != nb:
                        cur = parent[cur]
                        path.append(cur)
                    # path = [v, parent(v), ..., nb] -> [v, nb, ..., parent(v)]
                    return [v] + path[:0:-1]
            else:
                # every neighbour handled: this vertex is finished
                stack.pop()
    return None

グリッドBFS

code△
from collections import deque
INF = float('inf')
# 4-neighbourhood moves. The name shadows the builtin dir() but is kept for
# backward compatibility with code that uses it.
dir = [(1,0),(-1,0),(0,1),(0,-1)]


def inn(x,y):
    """Return True if (x, y) lies inside the H x W grid (H, W are module-level globals)."""
    return 0<=x<H and 0<=y<W


def bfs(x,y,grid=None):
    """Unweighted shortest distances on a grid from the start cell (x, y).

    grid: 2D sequence where a truthy cell is a wall. When omitted, the
          module-level globals A (grid), H (rows) and W (columns) are used.
          When given, its dimensions are taken from len(grid) / len(grid[0]).
    Returns an H x W list of distances; unreachable cells and walls stay INF.
    The start cell itself is not checked for being a wall.
    """
    if grid is None:
        grid, h, w = A, H, W
    else:
        h = len(grid)
        w = len(grid[0]) if h else 0
    dist = [[INF] * w for _ in range(h)]
    dist[x][y] = 0
    q = deque([(x, y)])
    while q:
        cx, cy = q.popleft()
        nd = dist[cx][cy] + 1
        for dx, dy in dir:
            nx, ny = cx + dx, cy + dy
            if not (0 <= nx < h and 0 <= ny < w) or grid[nx][ny]:
                continue
            # With unit edge costs the first visit is already optimal.
            if dist[nx][ny] == INF:
                dist[nx][ny] = nd
                q.append((nx, ny))
    return dist
    

二部グラフ

code△
def dfs(x,col):
    """Try to 2-colour the component of x by DFS, giving x the colour col (1 or -1).

    Works on module-level globals:
      edge   -- adjacency list (edge[v] iterates the neighbours of v)
      colors -- colors[v] is 0 (uncoloured), 1 or -1; updated in place
      cnt    -- cnt[1] / cnt[0] count the vertices painted 1 / -1; updated in place
    Returns True if no edge joins two equally coloured vertices, else False.
    The return value is only this bool: the per-colour counts are read from cnt.
    On False the search stops at the first conflict, so colors and cnt are
    then only partially filled.
    """
    colors[x] = col
    # col == 1 -> cnt[1], col == -1 -> cnt[0]
    cnt[(col+1)//2]+=1
    for v in edge[x]:
        if colors[v]==col:
            return False
        if colors[v]==0 and not dfs(v,-col):
            return False
    return True

UnionFind

PythonでのUnion-Find(素集合データ構造)の実装と使い方

UnionFind

code〇
from collections import defaultdict
class UnionFind():
    """Disjoint-set union over 0 .. n-1 with union by size and path compression.

    parents[i] < 0 marks i as a root whose set has -parents[i] elements;
    otherwise parents[i] is the parent of i.
    """

    def __init__(self, n):
        self.n = n
        self.parents = [-1 for _ in range(n)]

    def find(self, x):
        """Return the root of x and point every node on the way directly at it."""
        par = self.parents
        root = x
        while par[root] >= 0:
            root = par[root]
        node = x
        while par[node] >= 0:
            nxt = par[node]
            par[node] = root
            node = nxt
        return root

    def union(self, x, y):
        """Merge the sets containing x and y (the larger set's root survives)."""
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return
        if self.parents[ry] < self.parents[rx]:
            rx, ry = ry, rx
        self.parents[rx] += self.parents[ry]
        self.parents[ry] = rx

    def size(self, x):
        """Number of elements in the set containing x."""
        root = self.find(x)
        return -self.parents[root]

    def same(self, x, y):
        """True if x and y belong to the same set."""
        rx = self.find(x)
        ry = self.find(y)
        return rx == ry

    def members(self, x):
        """All elements in the set of x, in increasing order (O(n))."""
        root = self.find(x)
        return list(filter(lambda i: self.find(i) == root, range(self.n)))

    def roots(self):
        """All current roots, in increasing order."""
        par = self.parents
        return [i for i in range(len(par)) if par[i] < 0]

    def group_count(self):
        """Number of disjoint sets."""
        return len(self.roots())

    def all_group_members(self):
        """defaultdict mapping each root to the list of members of its set."""
        groups = defaultdict(list)
        for elem in range(self.n):
            root = self.find(elem)
            groups[root].append(elem)
        return groups

# uf = UnionFind(N)

UnionFindLabel

code〇

class UnionFindLabel(UnionFind):
    """UnionFind whose elements are arbitrary distinct hashable labels.

    Each label is mapped to an internal index (d) and back (d_inv); the
    public methods take and return labels instead of indices.
    """

    def __init__(self, labels):
        assert len(labels) == len(set(labels))

        self.n = len(labels)
        self.parents = [-1] * self.n
        # label -> internal index
        self.d = dict(zip(labels, range(self.n)))
        # internal index -> label
        self.d_inv = dict(enumerate(labels))

    def find_label(self, x):
        """Label of the representative of x's set."""
        root = super().find(self.d[x])
        return self.d_inv[root]

    def union(self, x, y):
        ix, iy = self.d[x], self.d[y]
        super().union(ix, iy)

    def size(self, x):
        idx = self.d[x]
        return super().size(idx)

    def same(self, x, y):
        ix, iy = self.d[x], self.d[y]
        return super().same(ix, iy)

    def members(self, x):
        """Labels in x's set, in the order the labels were given."""
        root = self.find(self.d[x])
        result = []
        for i in range(self.n):
            if self.find(i) == root:
                result.append(self.d_inv[i])
        return result

    def roots(self):
        """Labels of all current roots."""
        return [self.d_inv[i] for i, p in enumerate(self.parents) if p < 0]

    def all_group_members(self):
        """defaultdict mapping each root label to the labels of its set."""
        groups = defaultdict(list)
        for idx in range(self.n):
            root_label = self.d_inv[self.find(idx)]
            groups[root_label].append(self.d_inv[idx])
        return groups


Weighted UnionFind

code〇

class WeightedUnionFind:
    """Union-find that also maintains weight (potential) differences.

    Slots 0 .. n are allocated (n+1 of them), so 1-indexed input works directly.
    After union(x, y, w), diff(x, y) == w for every later call.
    """

    def __init__(self, n):
        self.par = [i for i in range(n+1)]
        self.rank = [0] * (n+1)
        # weight[x]: potential of x relative to par[x]
        # (relative to the root once find(x) has compressed the path)
        self.weight = [0] * (n+1)

    def find(self, x):
        """Return the root of x, compressing the path and accumulating weights."""
        if self.par[x] == x:
            return x
        y = self.find(self.par[x])
        # find(par[x]) made weight[par[x]] relative to the root, so add it in
        self.weight[x] += self.weight[self.par[x]]
        self.par[x] = y
        return y

    def union(self, x, y, w):
        """Merge the sets of x and y so that diff(x, y) == w.

        Returns True if two different sets were merged. Returns False if x and y
        were already connected; nothing is modified in that case, so the existing
        relation is kept.
        """
        rx = self.find(x)
        ry = self.find(y)
        if rx == ry:
            # Without this guard the root would get a non-zero weight (and a
            # bogus rank bump), corrupting every later find()/diff().
            return False
        if self.rank[rx] < self.rank[ry]:
            # attach the shorter tree (x's) under y's root
            self.par[rx] = ry
            self.weight[rx] = w - self.weight[x] + self.weight[y]
        else:
            self.par[ry] = rx
            self.weight[ry] = -w - self.weight[y] + self.weight[x]
            if self.rank[rx] == self.rank[ry]:
                self.rank[rx] += 1
        return True

    def same(self, x, y):
        """True if x and y are in the same set."""
        return self.find(x) == self.find(y)

    def diff(self, x, y):
        """Return weight(x) - weight(y), i.e. the w given to union(x, y, w).

        Only meaningful when same(x, y) is True. The finds make it correct even
        if same() was not called first.
        """
        self.find(x)
        self.find(y)
        return self.weight[x] - self.weight[y]


SortedSet

リンク

SortedSet

code〇
# https://github.com/tatyam-prime/SortedSet/blob/main/SortedSet.py
import math
from bisect import bisect_left, bisect_right
from typing import Generic, Iterable, Iterator, List, Tuple, TypeVar, Optional
T = TypeVar('T')

class SortedSet(Generic[T]):
    """Sorted set of distinct, mutually comparable elements (sqrt decomposition).

    self.a is a list of sorted buckets whose concatenation is strictly
    increasing; buckets that become empty are deleted. self.size is the
    total number of elements. Most operations scan the buckets linearly and
    then bisect inside one, hence the O(sqrt N) bounds in the docstrings.
    Vendored from https://github.com/tatyam-prime/SortedSet.
    """
    # Initial bucket count is ceil(sqrt(n / BUCKET_RATIO)) (see __init__).
    BUCKET_RATIO = 16
    # A bucket is split in half once its length exceeds len(self.a) * SPLIT_RATIO (see add).
    SPLIT_RATIO = 24
    
    def __init__(self, a: Iterable[T] = []) -> None:
        "Make a new SortedSet from iterable. / O(N) if sorted and unique / O(N log N)"
        # Copy first: the default [] is never mutated, so the mutable default is harmless.
        a = list(a)
        n = len(a)
        # Sort only when needed so already-sorted input stays O(N).
        if any(a[i] > a[i + 1] for i in range(n - 1)):
            a.sort()
        # Input is sorted here; drop duplicates if any adjacent pair is equal.
        if any(a[i] >= a[i + 1] for i in range(n - 1)):
            a, b = [], a
            for x in b:
                if not a or a[-1] != x:
                    a.append(x)
        n = self.size = len(a)
        # Cut into num_bucket nearly equal slices (num_bucket is 0 for empty input).
        num_bucket = int(math.ceil(math.sqrt(n / self.BUCKET_RATIO)))
        self.a = [a[n * i // num_bucket : n * (i + 1) // num_bucket] for i in range(num_bucket)]

    def __iter__(self) -> Iterator[T]:
        for i in self.a:
            for j in i: yield j

    def __reversed__(self) -> Iterator[T]:
        for i in reversed(self.a):
            for j in reversed(i): yield j
    
    def __eq__(self, other) -> bool:
        # Compares element sequences, so any iterable with the same elements in order is equal.
        return list(self) == list(other)
    
    def __len__(self) -> int:
        return self.size
    
    def __repr__(self) -> str:
        return "SortedSet" + str(self.a)
    
    def __str__(self) -> str:
        s = str(list(self))
        return "{" + s[1 : len(s) - 1] + "}"

    def _position(self, x: T) -> Tuple[List[T], int, int]:
        "return the bucket, index of the bucket and position in which x should be. self must not be empty."
        # If x exceeds every element, the loop ends on the last bucket and the
        # returned position is len(a) (append at the very end).
        for i, a in enumerate(self.a):
            if x <= a[-1]: break
        return (a, i, bisect_left(a, x))

    def __contains__(self, x: T) -> bool:
        if self.size == 0: return False
        a, _, i = self._position(x)
        return i != len(a) and a[i] == x

    def add(self, x: T) -> bool:
        "Add an element and return True if added. / O(√N)"
        if self.size == 0:
            self.a = [[x]]
            self.size = 1
            return True
        a, b, i = self._position(x)
        if i != len(a) and a[i] == x: return False
        a.insert(i, x)
        self.size += 1
        # Keep bucket lengths bounded relative to the bucket count by splitting in two.
        if len(a) > len(self.a) * self.SPLIT_RATIO:
            mid = len(a) >> 1
            self.a[b:b+1] = [a[:mid], a[mid:]]
        return True
    
    def _pop(self, a: List[T], b: int, i: int) -> T:
        # Remove a[i] from bucket a (= self.a[b]); delete the bucket if it became
        # empty, because other methods read a[0] / a[-1] of every bucket.
        ans = a.pop(i)
        self.size -= 1
        if not a: del self.a[b]
        return ans

    def discard(self, x: T) -> bool:
        "Remove an element and return True if removed. / O(√N)"
        if self.size == 0: return False
        a, b, i = self._position(x)
        if i == len(a) or a[i] != x: return False
        self._pop(a, b, i)
        return True
    
    # lt / le / gt / ge fall off the end and return None when no bucket qualifies.
    def lt(self, x: T) -> Optional[T]:
        "Find the largest element < x, or None if it doesn't exist."
        for a in reversed(self.a):
            if a[0] < x:
                return a[bisect_left(a, x) - 1]

    def le(self, x: T) -> Optional[T]:
        "Find the largest element <= x, or None if it doesn't exist."
        for a in reversed(self.a):
            if a[0] <= x:
                return a[bisect_right(a, x) - 1]

    def gt(self, x: T) -> Optional[T]:
        "Find the smallest element > x, or None if it doesn't exist."
        for a in self.a:
            if a[-1] > x:
                return a[bisect_right(a, x)]

    def ge(self, x: T) -> Optional[T]:
        "Find the smallest element >= x, or None if it doesn't exist."
        for a in self.a:
            if a[-1] >= x:
                return a[bisect_left(a, x)]
    
    def __getitem__(self, i: int) -> T:
        "Return the i-th element."
        # Negative i counts from the end, like list indexing; raises IndexError out of range.
        if i < 0:
            for a in reversed(self.a):
                i += len(a)
                if i >= 0: return a[i]
        else:
            for a in self.a:
                if i < len(a): return a[i]
                i -= len(a)
        raise IndexError
    
    def pop(self, i: int = -1) -> T:
        "Pop and return the i-th element."
        if i < 0:
            # b counts buckets from the end; ~b (== -b-1) is the matching negative index.
            for b, a in enumerate(reversed(self.a)):
                i += len(a)
                if i >= 0: return self._pop(a, ~b, i)
        else:
            for b, a in enumerate(self.a):
                if i < len(a): return self._pop(a, b, i)
                i -= len(a)
        raise IndexError
    
    def index(self, x: T) -> int:
        "Count the number of elements < x."
        ans = 0
        for a in self.a:
            if a[-1] >= x:
                return ans + bisect_left(a, x)
            ans += len(a)
        return ans

    def index_right(self, x: T) -> int:
        "Count the number of elements <= x."
        ans = 0
        for a in self.a:
            if a[-1] > x:
                return ans + bisect_right(a, x)
            ans += len(a)
        return ans
        

SortedMultiSet

code〇
# https://github.com/tatyam-prime/SortedSet/blob/main/SortedMultiset.py
import math
from bisect import bisect_left, bisect_right
from typing import Generic, Iterable, Iterator, List, Tuple, TypeVar, Optional
T = TypeVar('T')

class SortedMultiset(Generic[T]):
    """Sorted multiset of mutually comparable elements (sqrt decomposition).

    self.a is a list of sorted buckets whose concatenation is non-decreasing;
    buckets that become empty are deleted. self.size is the total number of
    elements, duplicates included. Most operations scan the buckets linearly
    and then bisect inside one, hence the O(sqrt N) bounds in the docstrings.
    Vendored from https://github.com/tatyam-prime/SortedSet.
    """
    # Initial bucket count is ceil(sqrt(n / BUCKET_RATIO)) (see __init__).
    BUCKET_RATIO = 16
    # A bucket is split in half once its length exceeds len(self.a) * SPLIT_RATIO (see add).
    SPLIT_RATIO = 24
    
    def __init__(self, a: Iterable[T] = []) -> None:
        "Make a new SortedMultiset from iterable. / O(N) if sorted / O(N log N)"
        # Copy first: the default [] is never mutated, so the mutable default is harmless.
        a = list(a)
        n = self.size = len(a)
        # Sort only when needed so already-sorted input stays O(N).
        if any(a[i] > a[i + 1] for i in range(n - 1)):
            a.sort()
        # Cut into num_bucket nearly equal slices (num_bucket is 0 for empty input).
        num_bucket = int(math.ceil(math.sqrt(n / self.BUCKET_RATIO)))
        self.a = [a[n * i // num_bucket : n * (i + 1) // num_bucket] for i in range(num_bucket)]

    def __iter__(self) -> Iterator[T]:
        for i in self.a:
            for j in i: yield j

    def __reversed__(self) -> Iterator[T]:
        for i in reversed(self.a):
            for j in reversed(i): yield j
    
    def __eq__(self, other) -> bool:
        # Compares element sequences, so any iterable with the same elements in order is equal.
        return list(self) == list(other)
    
    def __len__(self) -> int:
        return self.size
    
    def __repr__(self) -> str:
        return "SortedMultiset" + str(self.a)
    
    def __str__(self) -> str:
        s = str(list(self))
        return "{" + s[1 : len(s) - 1] + "}"

    def _position(self, x: T) -> Tuple[List[T], int, int]:
        "return the bucket, index of the bucket and position in which x should be. self must not be empty."
        # If x exceeds every element, the loop ends on the last bucket and the
        # returned position is len(a) (append at the very end).
        for i, a in enumerate(self.a):
            if x <= a[-1]: break
        return (a, i, bisect_left(a, x))

    def __contains__(self, x: T) -> bool:
        if self.size == 0: return False
        a, _, i = self._position(x)
        return i != len(a) and a[i] == x

    def count(self, x: T) -> int:
        "Count the number of x."
        return self.index_right(x) - self.index(x)

    def add(self, x: T) -> None:
        "Add an element. / O(√N)"
        if self.size == 0:
            self.a = [[x]]
            self.size = 1
            return
        a, b, i = self._position(x)
        a.insert(i, x)
        self.size += 1
        # Keep bucket lengths bounded relative to the bucket count by splitting in two.
        if len(a) > len(self.a) * self.SPLIT_RATIO:
            mid = len(a) >> 1
            self.a[b:b+1] = [a[:mid], a[mid:]]
    
    def _pop(self, a: List[T], b: int, i: int) -> T:
        # Remove a[i] from bucket a (= self.a[b]); delete the bucket if it became
        # empty, because other methods read a[0] / a[-1] of every bucket.
        ans = a.pop(i)
        self.size -= 1
        if not a: del self.a[b]
        return ans

    def discard(self, x: T) -> bool:
        "Remove an element and return True if removed. / O(√N)"
        # Removes a single occurrence of x, not all copies.
        if self.size == 0: return False
        a, b, i = self._position(x)
        if i == len(a) or a[i] != x: return False
        self._pop(a, b, i)
        return True

    # lt / le / gt / ge fall off the end and return None when no bucket qualifies.
    def lt(self, x: T) -> Optional[T]:
        "Find the largest element < x, or None if it doesn't exist."
        for a in reversed(self.a):
            if a[0] < x:
                return a[bisect_left(a, x) - 1]

    def le(self, x: T) -> Optional[T]:
        "Find the largest element <= x, or None if it doesn't exist."
        for a in reversed(self.a):
            if a[0] <= x:
                return a[bisect_right(a, x) - 1]

    def gt(self, x: T) -> Optional[T]:
        "Find the smallest element > x, or None if it doesn't exist."
        for a in self.a:
            if a[-1] > x:
                return a[bisect_right(a, x)]

    def ge(self, x: T) -> Optional[T]:
        "Find the smallest element >= x, or None if it doesn't exist."
        for a in self.a:
            if a[-1] >= x:
                return a[bisect_left(a, x)]
    
    def __getitem__(self, i: int) -> T:
        "Return the i-th element."
        # Negative i counts from the end, like list indexing; raises IndexError out of range.
        if i < 0:
            for a in reversed(self.a):
                i += len(a)
                if i >= 0: return a[i]
        else:
            for a in self.a:
                if i < len(a): return a[i]
                i -= len(a)
        raise IndexError
    
    def pop(self, i: int = -1) -> T:
        "Pop and return the i-th element."
        if i < 0:
            # b counts buckets from the end; ~b (== -b-1) is the matching negative index.
            for b, a in enumerate(reversed(self.a)):
                i += len(a)
                if i >= 0: return self._pop(a, ~b, i)
        else:
            for b, a in enumerate(self.a):
                if i < len(a): return self._pop(a, b, i)
                i -= len(a)
        raise IndexError

    def index(self, x: T) -> int:
        "Count the number of elements < x."
        ans = 0
        for a in self.a:
            if a[-1] >= x:
                return ans + bisect_left(a, x)
            ans += len(a)
        return ans

    def index_right(self, x: T) -> int:
        "Count the number of elements <= x."
        ans = 0
        for a in self.a:
            if a[-1] > x:
                return ans + bisect_right(a, x)
            ans += len(a)
        return ans
        

Segment Tree

セグ木
遅延セグ木

func_list

code〇
# Typical (segfunc, identity) pairs for SegTree; N and SegTree come from the surrounding code.
L = [0]*N
# range minimum
seg = SegTree(L, lambda x,y:min(x,y), float('inf'))
# range maximum
seg = SegTree(L, lambda x,y:max(x,y), -float('inf'))
# range sum
seg = SegTree(L, lambda x,y:x+y, 0)
# range product
seg = SegTree(L, lambda x,y:x*y, 1)
# range gcd: pass math.gcd itself (a lambda would have to call it: lambda x,y: math.gcd(x,y))
seg = SegTree(L, math.gcd, 0)

Segment Tree

一点更新、一点加算、区間和 他いろいろ

code〇
def segfunc(x,y):
    """Default combine function for SegTree: addition (range sum)."""
    total = x + y
    return total
class SegTree:
    """Segment tree over a monoid (segfunc, ide_ele).

    init_val : initial values of positions 0 .. len(init_val)-1
    segfunc  : associative binary operation (it does NOT need to be commutative)
    ide_ele  : identity element of segfunc (e.g. 0 for +, inf for min)
    Point update / point add in O(log n); query(l, r) folds [l, r) in O(log n).
    """
    def __init__(self,init_val,segfunc,ide_ele):
        n = len(init_val)
        self.segfunc = segfunc
        self.ide_ele = ide_ele
        # number of leaves: smallest power of two >= n
        self.num = 1<<(n-1).bit_length()
        self.tree = [ide_ele]*2*self.num
        for i in range(n):
            self.tree[self.num+i] = init_val[i]
        for i in range(self.num-1,0,-1):
            self.tree[i] = self.segfunc(self.tree[2*i],self.tree[2*i+1])

    def add(self,k,x):
        """Add x to position k (values must support +=)."""
        k += self.num
        self.tree[k] += x
        while k>1:
            # k^1 is the sibling; pass (left, right) in that order
            self.tree[k>>1] = self.segfunc(self.tree[k & ~1],self.tree[k | 1])
            k >>= 1

    def update(self,k,x):
        """Set position k to x."""
        k += self.num
        self.tree[k] = x
        while k>1:
            self.tree[k>>1] = self.segfunc(self.tree[k & ~1],self.tree[k | 1])
            k >>= 1

    def query(self,l,r):
        """Return segfunc folded over positions l .. r-1 in left-to-right order.

        Returns ide_ele for an empty range. Left and right parts are collected
        separately, so non-commutative operations (e.g. concatenation, matrix
        product) produce the correct order.
        """
        left = self.ide_ele
        right = self.ide_ele
        l += self.num
        r += self.num
        while l<r:
            if l&1:
                left = self.segfunc(left,self.tree[l])
                l += 1
            if r&1:
                r -= 1
                right = self.segfunc(self.tree[r],right)
            l >>= 1
            r >>= 1
        return self.segfunc(left,right)

Fenwick Tree

一点加算,区間和
ABC276

code〇
class Fenwick_Tree:
    """Binary indexed tree over positions 0 .. n-1: point add, half-open range sum."""

    def __init__(self, n):
        self._n = n
        # data[i-1] stores the sum of the block ending at 1-based index i
        self.data = [0 for _ in range(n)]

    def add(self, p, x):
        """Add x to position p (0-indexed)."""
        assert 0 <= p < self._n
        i = p + 1
        size = self._n
        while i <= size:
            self.data[i - 1] += x
            i += i & -i

    def sum(self, l, r):
        """Return the sum of positions l .. r-1."""
        assert 0 <= l <= r <= self._n
        upper = self._sum(r)
        lower = self._sum(l)
        return upper - lower

    def _sum(self, r):
        """Return the prefix sum of positions 0 .. r-1."""
        total = 0
        i = r
        while i > 0:
            total += self.data[i - 1]
            # clear the lowest set bit
            i &= i - 1
        return total

Lazy Segment Tree(区間加算)

code〇
def segfunc(x,y):
    """Combine function: addition."""
    total = x + y
    return total
class LazySegTree_RAQ:
    """Lazy segment tree: range add, range fold.

    add(l, r, x) adds x to a whole node value once, regardless of how many
    positions the node covers. It is therefore only correct for segfuncs with
    segfunc(a + x, b + x) == segfunc(a, b) + x, such as min and max (not sum).
    Ranges are half-open [l, r).
    """
    def __init__(self,init_val,segfunc,ide_ele):
        n = len(init_val)
        self.segfunc = segfunc
        self.ide_ele = ide_ele
        self.num = 1<<(n-1).bit_length()
        self.tree = [ide_ele]*2*self.num
        # pending addition not yet pushed to the children of each node
        self.lazy = [0]*2*self.num
        for i in range(n):
            self.tree[self.num+i] = init_val[i]
        for i in range(self.num-1,0,-1):
            self.tree[i] = self.segfunc(self.tree[2*i], self.tree[2*i+1])

    def gindex(self,l,r):
        """Yield, bottom-up, the ancestors of the nodes covering [l, r) that must be refreshed.

        Requires l < r; for an empty range it would also yield a leaf.
        """
        l += self.num
        r += self.num
        lm = l>>(l&-l).bit_length()
        rm = r>>(r&-r).bit_length()
        while r>l:
            if l<=lm:
                yield l
            if r<=rm:
                yield r
            r >>= 1
            l >>= 1
        while l:
            yield l
            l >>= 1

    def propagates(self,*ids):
        """Push pending additions of the given nodes down to their children, top-down."""
        for i in reversed(ids):
            v = self.lazy[i]
            if v==0:
                continue
            self.lazy[i] = 0
            self.lazy[2*i] += v
            self.lazy[2*i+1] += v
            self.tree[2*i] += v
            self.tree[2*i+1] += v

    def add(self,l,r,x):
        """Add x to every position in [l, r); an empty range is a no-op."""
        if l >= r:
            # gindex would yield a leaf here and tree[2*leaf] is out of range
            return
        ids = self.gindex(l,r)
        l += self.num
        r += self.num
        while l<r:
            if l&1:
                self.lazy[l] += x
                self.tree[l] += x
                l += 1
            if r&1:
                self.lazy[r-1] += x
                self.tree[r-1] += x
            r >>= 1
            l >>= 1
        for i in ids:
            self.tree[i] = self.segfunc(self.tree[2*i], self.tree[2*i+1]) + self.lazy[i]

    def query(self,l,r):
        """Return segfunc folded over [l, r); ide_ele for an empty range."""
        if l >= r:
            return self.ide_ele
        self.propagates(*self.gindex(l,r))
        res = self.ide_ele
        l += self.num
        r += self.num
        while l<r:
            if l&1:
                res = self.segfunc(res,self.tree[l])
                l += 1
            if r&1:
                res = self.segfunc(res,self.tree[r-1])
            l >>= 1
            r >>= 1
        return res

Lazy Segment Tree(区間更新)

code〇
def segfunc(x,y):
    """Combine function: minimum (keeps x unless y is strictly smaller)."""
    return y if y < x else x
class LazySegTree_RUQ:
    """Lazy segment tree: range assignment, range fold.

    update(l, r, x) stores x as the value of whole nodes, so it is only correct
    for segfuncs with segfunc(x, x) == x (min, max, gcd, ...), not for sum.
    None is used internally as "no pending assignment".
    Ranges are half-open [l, r).
    """
    def __init__(self,init_val,segfunc,ide_ele):
        n = len(init_val)
        self.segfunc = segfunc
        self.ide_ele = ide_ele
        self.num = 1<<(n-1).bit_length()
        self.tree = [ide_ele]*2*self.num
        # pending assignment for each node's children (None = nothing pending)
        self.lazy = [None]*2*self.num
        for i in range(n):
            self.tree[self.num+i] = init_val[i]
        for i in range(self.num-1,0,-1):
            self.tree[i] = self.segfunc(self.tree[2*i],self.tree[2*i+1])

    def gindex(self,l,r):
        """Yield, bottom-up, the ancestors of the nodes covering [l, r) that must be refreshed.

        Requires l < r; for an empty range it would also yield a leaf.
        """
        l += self.num
        r += self.num
        lm = l>>(l&-l).bit_length()
        rm = r>>(r&-r).bit_length()
        while r>l:
            if l<=lm:
                yield l
            if r<=rm:
                yield r
            r >>= 1
            l >>= 1
        while l:
            yield l
            l >>= 1

    def propagates(self,*ids):
        """Push pending assignments of the given nodes down to their children, top-down."""
        for i in reversed(ids):
            v = self.lazy[i]
            if v is None:
                continue
            self.lazy[i] = None
            self.lazy[2*i] = v
            self.lazy[2*i+1] = v
            self.tree[2*i] = v
            self.tree[2*i+1] = v

    def update(self,l,r,x):
        """Assign x to every position in [l, r); an empty range is a no-op."""
        if l >= r:
            # gindex would yield a leaf here and tree[2*leaf] is out of range
            return
        ids = self.gindex(l,r)
        # Older assignments above the range must reach the children before being overwritten.
        self.propagates(*self.gindex(l,r))
        l += self.num
        r += self.num
        while l<r:
            if l&1:
                self.lazy[l] = x
                self.tree[l] = x
                l += 1
            if r&1:
                self.lazy[r-1] = x
                self.tree[r-1] = x
            r >>= 1
            l >>= 1
        for i in ids:
            self.tree[i] = self.segfunc(self.tree[2*i], self.tree[2*i+1])

    def query(self,l,r):
        """Return segfunc folded over [l, r); ide_ele for an empty range."""
        if l >= r:
            return self.ide_ele
        self.propagates(*self.gindex(l,r))
        res = self.ide_ele
        l += self.num
        r += self.num
        while l<r:
            if l&1:
                res = self.segfunc(res,self.tree[l])
                l += 1
            if r&1:
                res = self.segfunc(res,self.tree[r-1])
            l >>= 1
            r >>= 1
        return res

LCA

ブログ

LCA

最小共通祖先とその距離を求める
ABC014

code〇
from collections import deque
import sys
sys.setrecursionlimit(10 ** 8)

class LCA:
    """Lowest common ancestor and tree distance by binary lifting.

    G: adjacency list of an unweighted tree (G[v] iterates the neighbours of v).
    root: root vertex (default 0).
    Preprocessing is O(V log V); query() and distance() are O(log V).
    """
    def __init__(self, G, root=0):
        V = len(G)
        K = 1
        while 1<<K < V:
            K += 1
        # parent[k][v] = 2^k-th ancestor of v, or -1 above the root
        self.parent = [[-1]*V for _ in range(K)]
        # dist[v] = depth of v (edges from root); -1 if unreachable from root
        self.dist = [-1]*V
        self.bfs(G, root)
        for i in range(K-1):
            for j in range(V):
                if self.parent[i][j] != -1:
                    self.parent[i+1][j] = self.parent[i][self.parent[i][j]]

    def bfs(self, G, root):
        """Fill self.dist (depths) and self.parent[0] (direct parents) by BFS from root."""
        self.dist[root] = 0
        que = deque()
        que.append(root)
        while que:
            n = que.popleft()
            for v in G[n]:
                if self.dist[v] == -1:
                    self.dist[v] = self.dist[n]+1
                    self.parent[0][v] = n
                    que.append(v)

    def query(self, a, b):
        """Return the lowest common ancestor of a and b."""
        if self.dist[a] < self.dist[b]:
            a, b = b, a
        K = len(self.parent)
        # lift a to the depth of b
        for i in range(K):
            if (self.dist[a]-self.dist[b]) & 1<<i:
                a = self.parent[i][a]
        if a == b:
            return a
        # lift both while their ancestors differ; they end just below the LCA
        for i in reversed(range(K)):
            if self.parent[i][a] != self.parent[i][b]:
                a = self.parent[i][a]
                b = self.parent[i][b]
        return self.parent[0][a]

    def distance(self, a, b):
        """Return the number of edges on the tree path between a and b."""
        return self.dist[a] + self.dist[b] - 2 * self.dist[self.query(a, b)]

パスの最大重みを求める

ダブリングと呼ぶらしい
めちゃマイナー
ABC235

code〇

class SecondMinimumSpanningTreeLight:
    """Binary lifting (doubling) on a weighted tree for path-maximum queries.

    graph: list of dicts, graph[u][v] = weight of edge u-v, stored in both
    directions. The tree is rooted at vertex 0; vertices unreachable from 0
    are not supported.
    """
    def __init__(self, graph):
        self.n = len(graph)
        self.parent = [-1] * self.n
        self.depth = [-1] * self.n
        stack = deque([0])
        self.depth[0] = 0
        while stack:
            i = stack.popleft()
            for j in graph[i]:
                if self.depth[j] == -1:
                    self.depth[j] = self.depth[i] + 1
                    self.parent[j] = i
                    stack.append(j)

        # Flattened tables (row i, column j at i * cols + j):
        #   dp[i][j]             = 2^j-th ancestor of i (-1 above the root)
        #   maximum_weight[i][j] = max edge weight on that upward jump
        self.cols = max(2, math.ceil(math.log2(self.n)))
        self.dp = [-1] * self.cols * self.n
        self.maximum_weight = [-1] * self.cols * self.n
        for i in range(self.n):
            self.dp[i * self.cols] = self.parent[i]
            if self.parent[i] != -1:
                self.maximum_weight[i * self.cols] = graph[self.parent[i]][i]

        for j in range(1, self.cols):
            for i in range(self.n):
                ancestor = self.dp[i * self.cols + j - 1]
                max_weight = self.maximum_weight[i * self.cols + j - 1]
                self.maximum_weight[i * self.cols + j] = max_weight
                if ancestor != -1:
                    self.dp[i * self.cols + j] = self.dp[ancestor * self.cols + j - 1]
                    max_weight = max(self.maximum_weight[i * self.cols + j], self.maximum_weight[ancestor * self.cols + j - 1])
                    self.maximum_weight[i * self.cols + j] = max_weight

    def get_maximum_weight(self, x, y):
        """Return the maximum edge weight on the tree path between x and y.

        Returns -1 when x == y (empty path).
        """
        cols = self.cols
        dp = self.dp
        mw = self.maximum_weight
        if self.depth[x] < self.depth[y]:
            x, y = y, x
        max_weight = -1
        # Lift x to y's depth, one set bit of the depth difference at a time
        # (integer bit ops instead of float log2, which can misround).
        diff = self.depth[x] - self.depth[y]
        j = 0
        while diff:
            if diff & 1:
                max_weight = max(max_weight, mw[x * cols + j])
                x = dp[x * cols + j]
            diff >>= 1
            j += 1
        if x == y:
            return max_weight
        # Lift both while their ancestors differ; afterwards both are children of the LCA.
        for j in reversed(range(cols)):
            ax = dp[x * cols + j]
            ay = dp[y * cols + j]
            if ax != ay:
                max_weight = max(max_weight, mw[x * cols + j], mw[y * cols + j])
                x, y = ax, ay
        # include the two final edges up to the LCA
        return max(max_weight, mw[x * cols], mw[y * cols])


# Build one weighted adjacency dict per vertex from 1-indexed (a, b, c) edges
# (n and edges come from the surrounding solution), then preprocess the tree.
graph = [{} for _ in range(n)]
for a, b, c in edges:
    a, b = a - 1, b - 1
    graph[a][b] = graph[b][a] = c

mst = SecondMinimumSpanningTreeLight(graph)

二分探索

bisect

参考

code〇
import bisect
def index(a, x):   # index of the leftmost element equal to x in sorted a, or -1
    """Return the index of the leftmost element of sorted list a equal to x, else -1."""
    pos = bisect.bisect_left(a, x)
    found = pos < len(a) and a[pos] == x
    return pos if found else -1
    
def find_lt(a, x, default=-1):   # largest element < x
    """Return the largest element of sorted list a that is strictly less than x.

    Returns default (-1 unless given) when no such element exists; pass e.g.
    default=None when -1 can be a legitimate element of a.
    """
    i = bisect.bisect_left(a, x)
    return a[i-1] if i else default

def find_le(a, x, default=-1):   # largest element <= x
    """Return the largest element of sorted list a that is less than or equal to x.

    Returns default (-1 unless given) when no such element exists; pass e.g.
    default=None when -1 can be a legitimate element of a.
    """
    i = bisect.bisect_right(a, x)
    return a[i-1] if i else default

def find_gt(a, x, default=-1):   # smallest element > x
    """Return the smallest element of sorted list a that is strictly greater than x.

    Returns default (-1 unless given) when no such element exists; pass e.g.
    default=None when -1 can be a legitimate element of a.
    """
    i = bisect.bisect_right(a, x)
    return a[i] if i != len(a) else default

def find_ge(a, x, default=-1):   # smallest element >= x
    """Return the smallest element of sorted list a that is greater than or equal to x.

    Returns default (-1 unless given) when no such element exists; pass e.g.
    default=None when -1 can be a legitimate element of a.
    """
    i = bisect.bisect_left(a, x)
    return a[i] if i != len(a) else default

めぐる式二分探索

参考

code〇
def is_ok(arg):
    """Predicate used by meguru_bisect; replace the body for each problem.

    This placeholder always returns None (falsy).
    """
    return None
      
def meguru_bisect(ng, ok, pred=None):
    """Binary search for the boundary between a failing and a passing value.

    ng: a value assumed to fail the predicate (it is never evaluated)
    ok: a value assumed to pass the predicate (it is never evaluated)
    pred: monotone predicate; when omitted, the module-level is_ok is used.

    Returns the passing value adjacent to the boundary: the smallest passing
    value when ng < ok, the largest one when ng > ok. Typically ng is
    (smallest possible value - 1) and ok is (largest possible value + 1);
    swap the roles when searching for a maximum.
    """
    check = is_ok if pred is None else pred
    while (abs(ok - ng) > 1):
        mid = (ok + ng) // 2
        if check(mid):
            ok = mid
        else:
            ng = mid
    return ok

dp

ナップザック

yaketake08’s 実装メモ(ナップサック問題 (重みが小さいケース))
yaketake08’s 実装メモ(ナップサック問題 (価値が小さいケース))

  • $N$ 種類の品物がある
  • $i$ 番目の品物の価値は $v_i$, 重さは $w_i$, 個数は $c_i$
  • 重さの総和 $W$ まで入るナップサックに入れる
  • ナップサックに入る品物の価値を最大化する

0-1ナップザック:重みが小さいケース

$c_i = 1$
計算量は$O(NW)$

code〇
# ws[i]: weight of item i, vs[i]: value of item i
def solve(N, W, ws, vs):
    """0-1 knapsack (each item at most once): max total value with total weight <= W.

    Uses the first N items of ws / vs. O(N * W) time, O(W) memory.
    """
    best = [0 for _ in range(W + 1)]
    for i in range(N):
        value, weight = vs[i], ws[i]
        # scan capacities downwards so each item is counted at most once
        for cap in reversed(range(weight, W + 1)):
            cand = best[cap - weight] + value
            if cand > best[cap]:
                best[cap] = cand
    return max(best)

個数制限なしナップザック:重みが小さいケース

$c_i$は無制限
計算量は$O(NW)$

code〇
# ws[i]: weight of item i, vs[i]: value of item i
def solve(N, W, ws, vs):
    """Unbounded knapsack (unlimited copies): max total value with total weight <= W.

    Uses the first N items of ws / vs. O(N * W) time, O(W) memory.
    """
    best = [0 for _ in range(W + 1)]
    for i in range(N):
        value, weight = vs[i], ws[i]
        # scan capacities upwards so the same item may be reused
        for cap in range(weight, W + 1):
            cand = best[cap - weight] + value
            if cand > best[cap]:
                best[cap] = cand
    return max(best)

0-1ナップザック:価値が小さいケース

計算量は$O(N^2\max_i{v_i})$

code〇
# ws[i]: weight of item i, vs[i]: value of item i
def solve(N, W, ws, vs):
    """0-1 knapsack indexed by total value (for small values).

    Returns the largest total value whose minimum total weight is <= W.
    O(N * sum(vs)) time. Raises ValueError if nothing fits (W < 0).
    """
    V = sum(vs)
    # min_w[t] = minimum total weight reaching total value exactly t; W+1 means "does not fit"
    min_w = [W + 1 for _ in range(V + 1)]
    min_w[0] = 0
    for i in range(N):
        value, weight = vs[i], ws[i]
        # scan values downwards so each item is used at most once
        for t in reversed(range(value, V + 1)):
            cand = min_w[t - value] + weight
            if cand < min_w[t]:
                min_w[t] = cand
    return max(t for t, wt in enumerate(min_w) if wt <= W)

巡回セールスマン問題

code〇

def tsp(dist,N):
    """Bitmask DP over paths that start at a fixed vertex and visit towns.

    Vertex layout:
      0 .. N-3 : towns
      N-2      : start
      N-1      : goal (not used here; add dist[i][N-1] to dp[full][i] yourself)
    dist[i][j]: cost of moving from i to j.

    Returns dp where dp[S][i] is the minimum cost of leaving the start,
    visiting exactly the towns in bitmask S and standing at town i. The
    value is INF (the module-level constant) when unreachable.
    Time O(2^n * n^2) and memory O(2^n * n), with n = N - 2.
    """
    n = N - 2
    dp = [[INF] * n for _ in range(1 << n)]

    for i in range(n):
        dp[1 << i][i] = dist[-2][i]

    for mask in range(1 << n):
        row = dp[mask]
        for u in range(n):
            # u must be in mask and the state reachable, otherwise nothing to extend
            if not (mask >> u) & 1 or row[u] == INF:
                continue
            base = row[u]
            du = dist[u]
            for v in range(n):
                if (mask >> v) & 1:
                    continue
                new_mask = mask | (1 << v)
                cand = base + du[v]
                if cand < dp[new_mask][v]:
                    dp[new_mask][v] = cand
    return dp
    
##

その他

マージソート

code×
code

転倒数

ABC244D
反転数

code〇
# Version for sequences of 0-indexed integers (every value in 0 .. N-1).
def inversion(inds,N):
    """Count pairs i < j with inds[i] > inds[j] (strict), using a Fenwick tree.

    inds: sequence of ints in range(N); duplicates are allowed, and equal
          values are not counted as inversions.
    N:    exclusive upper bound of the values.
    O(len(inds) * log N).
    """
    bit = [0] * (N+1)
    def bit_add(x,w):
        while x <= N:
            bit[x] += w
            x += (x & -x)
    def bit_sum(x):
        # number of inserted values v with v + 1 <= x, i.e. v < x
        ret = 0
        while x > 0:
            ret += bit[x]
            x -= (x & -x)
        return ret
    inv = 0
    # Scan right to left: the tree holds the elements to the right of ind.
    for ind in reversed(inds):
        inv += bit_sum(ind)
        bit_add(ind + 1, 1)
    return inv
    
# Version for any sequence of mutually comparable values.
def inversion(in_arr, N=None):
    """Count pairs i < j with in_arr[i] > in_arr[j] (strict), using a Fenwick tree.

    Each value is replaced by its rank in a stable sort. Equal values therefore
    get increasing ranks from left to right and are not counted as inversions.

    N is optional and ignored. It is accepted so calls written for the
    0-indexed inversion(inds, N) variant keep working, since this definition
    replaces that one when both are in the same file.
    O(n log n).
    """
    arr = list(in_arr)
    L = len(arr)
    # sorted() is stable, so ties keep their original left-to-right order
    order = sorted(range(L), key=arr.__getitem__)
    inds = [0] * L
    for rank, pos in enumerate(order):
        inds[pos] = rank
    bit = [0] * (L+1)
    def bit_add(x,w):
        while x <= L:
            bit[x] += w
            x += (x & -x)
    def bit_sum(x):
        # number of inserted ranks r with r + 1 <= x, i.e. r < x
        ret = 0
        while x > 0:
            ret += bit[x]
            x -= (x & -x)
        return ret
    inv = 0
    for ind in reversed(inds):
        inv += bit_sum(ind)
        bit_add(ind + 1, 1)
    return inv
    

素因数分解

yaketake08’s 実装メモ(素因数分解 (試し割り法))

code〇
# N: the number to factorize

def factorization(N):
    """Return the prime factors of N in non-decreasing order, with multiplicity.

    Trial division, O(sqrt N). Returns [] for N <= 1.
    """
    factors = []
    rest = N
    d = 2
    while d * d <= rest:
        while not rest % d:
            rest //= d
            factors.append(d)
        d += 1
    # whatever remains above 1 is a single prime factor larger than sqrt(rest)
    if rest > 1:
        factors.append(rest)
    return factors

素数テーブル作成

code×
code

拡張ユークリッド

yaketake08’s 実装メモ(最大公約数 ((拡張)ユークリッドの互除法))

code〇
# Euclidean Algorithm
def gcd(m, n):
    """Return the greatest common divisor of m and n by Euclid's algorithm.

    gcd(m, 0) == m (the recursive form used to raise ZeroDivisionError here).
    With negative inputs the sign follows Python's % semantics.
    """
    while n:
        m, n = n, m % n
    return m

# Euclidean Algorithm (iterative)
def gcd2(m, n):
    """Return the greatest common divisor of m and n without recursion."""
    a, b = m, n
    while b:
        r = a % b
        a = b
        b = r
    return a

# Extended Euclidean Algorithm
def extgcd(a, b):
    """Return (d, x, y) with a*x + b*y == d, where d is Euclid's gcd of a and b.

    Iterative form; it walks the same quotient sequence as the recursive
    version and yields the same coefficients.
    """
    x0, y0, x1, y1 = 1, 0, 0, 1
    while b:
        q = a // b
        a, b = b, a % b
        x0, x1 = x1, x0 - q * x1
        y0, y1 = y1, y0 - q * y1
    return a, x0, y0

# lcm (least common multiple)
def lcm(m, n):
    """Least common multiple of m and n via this file's gcd (divide first to keep numbers small)."""
    g = gcd(m, n)
    return (m // g) * n

zアルゴリズム

SとS[i:]の最長共通接頭辞の長さのリストを返す
ABC284
z-algorithm

code〇
def z(S):
    """Z-algorithm: A[i] = length of the longest common prefix of S and S[i:].

    A[0] == len(S), as the definition implies (previously left as 0).
    S may be any indexable sequence (str, list, ...). Runs in O(len(S)).
    """
    N = len(S)
    A = [0] * N
    if N == 0:
        return A
    A[0] = N
    i, j = 1, 0
    while i < N:
        # extend the match of S against S[i:] starting from the known length j
        while i + j < N and S[j] == S[i + j]:
            j += 1
        A[i] = j
        if j == 0:
            i += 1
            continue
        # reuse earlier values inside the matched window [i, i+j)
        k = 1
        while i + k < N and k + A[k] < j:
            A[i + k] = A[k]
            k += 1
        i += k
        j -= k
    return A

最長共通部分列

yaketake08’s 実装メモ(最長共通部分列)
計算量は$O(NM)$

code
def solve(S, T):
    """Return one longest common subsequence of strings S and T. O(len(S) * len(T))."""
    n, m = len(S), len(T)
    # table[i][j] = LCS length of the suffixes S[i:] and T[j:]
    table = [[0] * (m + 1) for _ in range(n + 1)]

    for i in reversed(range(n)):
        row, below = table[i], table[i + 1]
        for j in reversed(range(m)):
            best = row[j + 1] if row[j + 1] > below[j] else below[j]
            if S[i] == T[j] and below[j + 1] + 1 > best:
                best = below[j + 1] + 1
            row[j] = best

    # table[0][0] is the LCS length; walk the table to rebuild one witness.
    picked = []
    i = j = 0
    while i < n and j < m:
        if S[i] == T[j]:
            picked.append(S[i])
            i, j = i + 1, j + 1
        elif table[i][j] == table[i + 1][j]:
            i += 1
        elif table[i][j] == table[i][j + 1]:
            j += 1
    return "".join(picked)

最長回文

最長回文

code×
code

幾何

yaketake08’s 実装メモ
リンク略

  • 点の線対称

  • 多角形の面積

  • 多角形の点包含判定

  • 線分同士の交差判定

  • 線分と頂点の最短距離

  • 直線同士の交点

  • 直線(線分)と円の交点

  • 円同士の交点

  • 2つの円の共通部分の面積

  • 円の共通接線の接点/接線

  • 三角形の外接円/内接円/傍接円

  • 凸包 (Graham Scan)

  • 凸多角形の点包含判定

  • 凸多角形同士の交差判定/交点

  • 直線による凸多角形の切断

  • 凸多角形と直線の交差判定/交点

  • 点から凸多角形への接線

  • 最遠点対 (キャリパー法)

  • 最近点対 (分割統治法)

  • 座標圧縮

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0