Library
Atcoder用pythonライブラリ関数版
まだ未完成
code : コードなし
code△ : コードはあるけどソースなし
code〇 : コードとソースあり
最短経路
ダイクストラ
code△
import heapq
def dijkstra(S, N, edge):
    """Single-source shortest paths for non-negative edge weights.

    S: source vertex, N: number of vertices,
    edge[u]: iterable of (v, w) pairs, one per edge u -> v of weight w.
    Returns a list whose v-th entry is the minimum cost from S to v
    (float('inf') when v is unreachable).
    """
    dist = [float('inf')] * N
    dist[S] = 0
    pq = [(0, S)]
    while pq:
        d, u = heapq.heappop(pq)
        if d > dist[u]:
            continue  # outdated heap entry; a shorter path was already found
        for nxt, weight in edge[u]:
            cand = dist[u] + weight
            if cand < dist[nxt]:
                dist[nxt] = cand
                heapq.heappush(pq, (cand, nxt))
    return dist
ベルマンフォード
code
code
トポロジカルソート
code
code
SCC
強連結成分分解
groups はトポロジカル順に並んでいる(辺 u→v があれば u の成分が先に来る)
ABC357
code〇
import sys
sys.setrecursionlimit(10**8)
class Scc:
    """Strongly connected component decomposition (Tarjan), ACL scc_graph style.

    Add directed edges with add_edge() and call scc() to get the components
    in topological order.
    """

    def __init__(self, n):
        """Create an empty directed graph on vertices 0 .. n-1."""
        self.n = n
        self.edges = []

    def add_edge(self, fr, to):
        """Add a directed edge fr -> to (both must lie in [0, n))."""
        assert 0 <= fr < self.n
        assert 0 <= to < self.n
        self.edges.append((fr, to))

    def scc(self):
        """Return the SCCs as lists of vertices, in topological order.

        For every edge u -> v, the component of u appears no later than the
        component of v. Vertices inside a component are listed in increasing
        order. Afterwards self.ids[v] is the index of v's component and
        self.group_num is the number of components.

        The DFS uses an explicit stack, so long paths do not hit Python's
        recursion limit or overflow the interpreter's C stack.
        """
        n = self.n
        # CSR adjacency: successors of v are csr_elist[csr_start[v]:csr_start[v+1]].
        csr_start = [0] * (n + 1)
        csr_elist = [0] * len(self.edges)
        for fr, _ in self.edges:
            csr_start[fr + 1] += 1
        for i in range(1, n + 1):
            csr_start[i] += csr_start[i - 1]
        counter = csr_start[:]
        for fr, to in self.edges:
            csr_elist[counter[fr]] = to
            counter[fr] += 1

        now_ord = group_num = 0
        visited = []               # Tarjan stack: vertices not yet assigned to an SCC
        low = [0] * n
        order = [-1] * n           # -1: unvisited, n: already assigned to an SCC
        ids = [0] * n
        next_edge = csr_start[:n]  # per-vertex cursor into csr_elist

        for root in range(n):
            if order[root] != -1:
                continue
            low[root] = order[root] = now_ord
            now_ord += 1
            visited.append(root)
            call_stack = [root]
            while call_stack:
                v = call_stack[-1]
                if next_edge[v] < csr_start[v + 1]:
                    to = csr_elist[next_edge[v]]
                    next_edge[v] += 1
                    if order[to] == -1:
                        # Equivalent of the recursive call _dfs(to).
                        low[to] = order[to] = now_ord
                        now_ord += 1
                        visited.append(to)
                        call_stack.append(to)
                    else:
                        low[v] = min(low[v], order[to])
                    continue
                # Every edge of v is explored: "return" from v.
                call_stack.pop()
                if low[v] == order[v]:
                    while True:
                        u = visited.pop()
                        order[u] = n
                        ids[u] = group_num
                        if u == v:
                            break
                    group_num += 1
                if call_stack:
                    parent = call_stack[-1]
                    low[parent] = min(low[parent], low[v])

        self.now_ord = now_ord
        self.group_num = group_num
        self.visited = visited
        self.low = low
        self.ord = order
        # Tarjan finishes components in reverse topological order; flip the ids.
        self.ids = [group_num - 1 - g for g in ids]
        groups = [[] for _ in range(group_num)]
        for v, g in enumerate(self.ids):
            groups[g].append(v)
        return groups
グリッドDFS
code
code
サイクル検出
dfsをして有向グラフに閉路が存在するかを判定する(True/False を返す。閉路の頂点集合は返さない)
ChatGPT作
PAST13I
code〇
import sys
sys.setrecursionlimit(10 ** 8)
def detect_cycle(edge):
    """Return True if the directed graph has a cycle (self-loops count).

    edge[v] lists the successors of vertex v (vertices 0 .. len(edge)-1).
    Uses an explicit stack (three-colour DFS) instead of recursion, so long
    paths cannot exceed Python's recursion limit.
    """
    N = len(edge)
    UNSEEN, ACTIVE, DONE = 0, 1, 2
    state = [UNSEEN] * N
    for start in range(N):
        if state[start] != UNSEEN:
            continue
        state[start] = ACTIVE
        stack = [(start, iter(edge[start]))]
        while stack:
            v, it = stack[-1]
            for w in it:
                if state[w] == ACTIVE:
                    return True  # edge back to a vertex on the current DFS path
                if state[w] == UNSEEN:
                    state[w] = ACTIVE
                    stack.append((w, iter(edge[w])))
                    break
            else:
                # All successors of v are finished.
                state[v] = DONE
                stack.pop()
    return False
グリッドBFS
code△
from collections import deque
INF = float('inf')
# 4-neighbourhood moves. NOTE: the name shadows the builtin dir(); kept for compatibility.
dir = [(1,0),(-1,0),(0,1),(0,-1)]
def inn(x, y):
    """Return True if (x, y) lies inside the H x W grid (H, W are globals)."""
    return 0 <= x < H and 0 <= y < W
def bfs(x, y, grid=None):
    """Shortest step counts from (x, y) on a grid, moving in 4 directions.

    grid[i][j] is truthy for a wall and falsy for a free cell. When grid is
    omitted, the globals A, H and W are used, as in the original snippet.
    Returns dist with dist[i][j] = number of steps, or INF if unreachable.
    """
    if grid is None:
        grid, h, w = A, H, W
    else:
        h = len(grid)
        w = len(grid[0]) if h else 0
    dist = [[INF] * w for _ in range(h)]
    dist[x][y] = 0
    q = deque([(x, y)])
    while q:
        cx, cy = q.popleft()
        nd = dist[cx][cy] + 1
        for dx, dy in dir:
            nx, ny = cx + dx, cy + dy
            if 0 <= nx < h and 0 <= ny < w and not grid[nx][ny] and nd < dist[nx][ny]:
                dist[nx][ny] = nd
                q.append((nx, ny))
    return dist
二部グラフ
code△
def dfs(x, col):
    """2-colour the vertices reachable from x by DFS (bipartiteness check).

    col is the colour for x (+1 or -1); colors[v] == 0 means "not coloured".
    Uses the globals colors, cnt and edge. Returns False as soon as an edge
    between two vertices of the same colour is found, otherwise True.
    Every vertex coloured here increments cnt[(col + 1) // 2] (cnt[1] for +1,
    cnt[0] for -1); the counts are updated even when False is returned, so
    they may be partial in that case.
    """
    colors[x] = col
    cnt[(col + 1) // 2] += 1
    for nxt in edge[x]:
        c = colors[nxt]
        if c == col:
            return False
        if c == 0 and not dfs(nxt, -col):
            return False
    return True
UnionFind
PythonでのUnion-Find(素集合データ構造)の実装と使い方
UnionFind
code〇
from collections import defaultdict
class UnionFind:
    """Disjoint-set union with union by size and path compression.

    parents[r] < 0 marks r as a root, and -parents[r] is its set's size;
    otherwise parents[x] is the parent of x.
    """

    def __init__(self, n):
        self.n = n
        self.parents = [-1] * n

    def find(self, x):
        """Return the root of x's set, pointing every node on the path at it."""
        root = x
        while self.parents[root] >= 0:
            root = self.parents[root]
        while x != root:
            nxt = self.parents[x]
            self.parents[x] = root
            x = nxt
        return root

    def union(self, x, y):
        """Merge the sets of x and y; the larger set's root stays root."""
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return
        if self.parents[rx] > self.parents[ry]:
            rx, ry = ry, rx
        self.parents[rx] += self.parents[ry]
        self.parents[ry] = rx

    def size(self, x):
        """Number of elements in x's set."""
        return -self.parents[self.find(x)]

    def same(self, x, y):
        """True if x and y are in the same set."""
        return self.find(x) == self.find(y)

    def members(self, x):
        """All elements in x's set, in increasing order."""
        root = self.find(x)
        return [v for v in range(self.n) if self.find(v) == root]

    def roots(self):
        """All current roots, in increasing order."""
        return [v for v, p in enumerate(self.parents) if p < 0]

    def group_count(self):
        """Number of disjoint sets."""
        return len(self.roots())

    def all_group_members(self):
        """defaultdict mapping each root to the list of its set's elements."""
        groups = defaultdict(list)
        for v in range(self.n):
            groups[self.find(v)].append(v)
        return groups
# uf = UnionFind(N)
UnionFindLabel
code〇
class UnionFindLabel(UnionFind):
    """UnionFind over arbitrary distinct hashable labels instead of 0..n-1.

    Labels are mapped to internal indices (d: label -> index, d_inv:
    index -> label). The inherited find() still takes an index; find_label()
    takes a label and returns the root's label.
    """

    def __init__(self, labels):
        assert len(labels) == len(set(labels))
        self.n = len(labels)
        self.parents = [-1] * self.n
        self.d = {}
        self.d_inv = {}
        for idx, label in enumerate(labels):
            self.d[label] = idx
            self.d_inv[idx] = label

    def find_label(self, x):
        """Label of the root of label x's set."""
        root_index = super().find(self.d[x])
        return self.d_inv[root_index]

    def union(self, x, y):
        """Merge the sets containing labels x and y."""
        super().union(self.d[x], self.d[y])

    def size(self, x):
        """Size of the set containing label x."""
        return super().size(self.d[x])

    def same(self, x, y):
        """True if labels x and y are in the same set."""
        return super().same(self.d[x], self.d[y])

    def members(self, x):
        """Labels in the same set as label x, in original label order."""
        root = self.find(self.d[x])
        result = []
        for i in range(self.n):
            if self.find(i) == root:
                result.append(self.d_inv[i])
        return result

    def roots(self):
        """Labels of all current roots."""
        return [self.d_inv[i] for i, p in enumerate(self.parents) if p < 0]

    def all_group_members(self):
        """defaultdict mapping each root label to the labels in its set."""
        groups = defaultdict(list)
        for i in range(self.n):
            root_label = self.d_inv[self.find(i)]
            groups[root_label].append(self.d_inv[i])
        return groups
Weighted UnionFind
code〇
class WeightedUnionFind:
    """Union-Find that also tracks potential differences between elements.

    Valid indices are 0..n (arrays have n+1 slots, so 1-indexed vertices work).
    union(x, y, w) records P(x) - P(y) = w, and diff(x, y) returns P(x) - P(y)
    for two elements of the same set. diff on different sets is meaningless.
    """
    def __init__(self, n):
        self.par = [i for i in range(n+1)]
        self.rank = [0] * (n+1)
        # weight[x] = P(x) - P(par[x]); equals P(x) - P(root) right after find(x).
        self.weight = [0] * (n+1)

    def find(self, x):
        """Return the root of x, compressing the path and accumulating weights."""
        if self.par[x] == x:
            return x
        y = self.find(self.par[x])
        # After the recursive call, weight[par[x]] is relative to the root.
        self.weight[x] += self.weight[self.par[x]]
        self.par[x] = y
        return y

    def union(self, x, y, w):
        """Merge the sets of x and y with the constraint P(x) - P(y) = w.

        Does nothing if x and y are already in the same set; relinking a root
        to itself would overwrite the root's weight and corrupt later diffs.
        """
        rx = self.find(x)
        ry = self.find(y)
        if rx == ry:
            return
        if self.rank[rx] < self.rank[ry]:
            # The tree of x is shorter: hang it under ry.
            self.par[rx] = ry
            self.weight[rx] = w - self.weight[x] + self.weight[y]
        else:
            self.par[ry] = rx
            self.weight[ry] = -w - self.weight[y] + self.weight[x]
            if self.rank[rx] == self.rank[ry]:
                self.rank[rx] += 1

    def same(self, x, y):
        """True if x and y are in the same set."""
        return self.find(x) == self.find(y)

    def diff(self, x, y):
        """Return P(x) - P(y); x and y must be in the same set."""
        # Compress first so both weights are measured from the common root.
        self.find(x)
        self.find(y)
        return self.weight[x] - self.weight[y]
SortedSet
SortedSet
code〇
# https://github.com/tatyam-prime/SortedSet/blob/main/SortedSet.py
import math
from bisect import bisect_left, bisect_right
from typing import Generic, Iterable, Iterator, List, Tuple, TypeVar, Optional
T = TypeVar('T')

class SortedSet(Generic[T]):
    """Ordered set of unique elements (vendored from tatyam-prime/SortedSet).

    Elements are kept in ascending order across self.a, a list of sorted
    buckets. Empty buckets are removed by _pop, so every bucket is non-empty.
    Lookups scan the buckets linearly and then bisect inside one bucket,
    which gives the O(sqrt N) costs noted on the methods.
    """
    # __init__ creates about sqrt(N / BUCKET_RATIO) buckets.
    BUCKET_RATIO = 16
    # add() splits a bucket once it holds more than (number of buckets) * SPLIT_RATIO items.
    SPLIT_RATIO = 24
    def __init__(self, a: Iterable[T] = []) -> None:
        "Make a new SortedSet from iterable. / O(N) if sorted and unique / O(N log N)"
        # The [] default is never mutated: it is copied by list(a) right away.
        a = list(a)
        n = len(a)
        if any(a[i] > a[i + 1] for i in range(n - 1)):
            a.sort()
        if any(a[i] >= a[i + 1] for i in range(n - 1)):
            # a is sorted here; drop adjacent duplicates.
            a, b = [], a
            for x in b:
                if not a or a[-1] != x:
                    a.append(x)
        n = self.size = len(a)
        num_bucket = int(math.ceil(math.sqrt(n / self.BUCKET_RATIO)))
        # Cut a into num_bucket consecutive slices of nearly equal length.
        self.a = [a[n * i // num_bucket : n * (i + 1) // num_bucket] for i in range(num_bucket)]
    def __iter__(self) -> Iterator[T]:
        # Ascending order.
        for i in self.a:
            for j in i: yield j
    def __reversed__(self) -> Iterator[T]:
        # Descending order.
        for i in reversed(self.a):
            for j in reversed(i): yield j
    def __eq__(self, other) -> bool:
        return list(self) == list(other)
    def __len__(self) -> int:
        return self.size
    def __repr__(self) -> str:
        return "SortedSet" + str(self.a)
    def __str__(self) -> str:
        s = str(list(self))
        return "{" + s[1 : len(s) - 1] + "}"
    def _position(self, x: T) -> Tuple[List[T], int, int]:
        "return the bucket, index of the bucket and position in which x should be. self must not be empty."
        # If x exceeds every element, the loop ends on the last bucket and the
        # returned position is len(bucket), i.e. append at the very end.
        for i, a in enumerate(self.a):
            if x <= a[-1]: break
        return (a, i, bisect_left(a, x))
    def __contains__(self, x: T) -> bool:
        if self.size == 0: return False
        a, _, i = self._position(x)
        return i != len(a) and a[i] == x
    def add(self, x: T) -> bool:
        "Add an element and return True if added. / O(√N)"
        if self.size == 0:
            self.a = [[x]]
            self.size = 1
            return True
        a, b, i = self._position(x)
        if i != len(a) and a[i] == x: return False
        a.insert(i, x)
        self.size += 1
        # Keep buckets small: split the grown bucket in half in place.
        if len(a) > len(self.a) * self.SPLIT_RATIO:
            mid = len(a) >> 1
            self.a[b:b+1] = [a[:mid], a[mid:]]
        return True
    def _pop(self, a: List[T], b: int, i: int) -> T:
        # Remove a[i] from bucket a (= self.a[b]); drop the bucket if it becomes empty.
        ans = a.pop(i)
        self.size -= 1
        if not a: del self.a[b]
        return ans
    def discard(self, x: T) -> bool:
        "Remove an element and return True if removed. / O(√N)"
        if self.size == 0: return False
        a, b, i = self._position(x)
        if i == len(a) or a[i] != x: return False
        self._pop(a, b, i)
        return True
    # lt / le / gt / ge fall off the end and return None when nothing qualifies.
    def lt(self, x: T) -> Optional[T]:
        "Find the largest element < x, or None if it doesn't exist."
        for a in reversed(self.a):
            if a[0] < x:
                return a[bisect_left(a, x) - 1]
    def le(self, x: T) -> Optional[T]:
        "Find the largest element <= x, or None if it doesn't exist."
        for a in reversed(self.a):
            if a[0] <= x:
                return a[bisect_right(a, x) - 1]
    def gt(self, x: T) -> Optional[T]:
        "Find the smallest element > x, or None if it doesn't exist."
        for a in self.a:
            if a[-1] > x:
                return a[bisect_right(a, x)]
    def ge(self, x: T) -> Optional[T]:
        "Find the smallest element >= x, or None if it doesn't exist."
        for a in self.a:
            if a[-1] >= x:
                return a[bisect_left(a, x)]
    def __getitem__(self, i: int) -> T:
        "Return the i-th element."
        # Negative i counts from the end; out-of-range raises IndexError.
        if i < 0:
            for a in reversed(self.a):
                i += len(a)
                if i >= 0: return a[i]
        else:
            for a in self.a:
                if i < len(a): return a[i]
                i -= len(a)
        raise IndexError
    def pop(self, i: int = -1) -> T:
        "Pop and return the i-th element."
        if i < 0:
            for b, a in enumerate(reversed(self.a)):
                i += len(a)
                # ~b == -(b + 1): the bucket's index counted from the end.
                if i >= 0: return self._pop(a, ~b, i)
        else:
            for b, a in enumerate(self.a):
                if i < len(a): return self._pop(a, b, i)
                i -= len(a)
        raise IndexError
    def index(self, x: T) -> int:
        "Count the number of elements < x."
        ans = 0
        for a in self.a:
            if a[-1] >= x:
                return ans + bisect_left(a, x)
            ans += len(a)
        return ans
    def index_right(self, x: T) -> int:
        "Count the number of elements <= x."
        ans = 0
        for a in self.a:
            if a[-1] > x:
                return ans + bisect_right(a, x)
            ans += len(a)
        return ans
SortedMultiSet
code〇
# https://github.com/tatyam-prime/SortedSet/blob/main/SortedMultiset.py
import math
from bisect import bisect_left, bisect_right
from typing import Generic, Iterable, Iterator, List, Tuple, TypeVar, Optional
T = TypeVar('T')

class SortedMultiset(Generic[T]):
    """Ordered multiset, duplicates allowed (vendored from tatyam-prime/SortedSet).

    Elements are kept in ascending order across self.a, a list of sorted
    buckets. Empty buckets are removed by _pop. Lookups scan the buckets
    linearly and then bisect inside one bucket (O(sqrt N) per operation).
    """
    # __init__ creates about sqrt(N / BUCKET_RATIO) buckets.
    BUCKET_RATIO = 16
    # add() splits a bucket once it holds more than (number of buckets) * SPLIT_RATIO items.
    SPLIT_RATIO = 24
    def __init__(self, a: Iterable[T] = []) -> None:
        "Make a new SortedMultiset from iterable. / O(N) if sorted / O(N log N)"
        # The [] default is never mutated: it is copied by list(a) right away.
        a = list(a)
        n = self.size = len(a)
        if any(a[i] > a[i + 1] for i in range(n - 1)):
            a.sort()
        num_bucket = int(math.ceil(math.sqrt(n / self.BUCKET_RATIO)))
        # Cut a into num_bucket consecutive slices of nearly equal length.
        self.a = [a[n * i // num_bucket : n * (i + 1) // num_bucket] for i in range(num_bucket)]
    def __iter__(self) -> Iterator[T]:
        # Ascending order.
        for i in self.a:
            for j in i: yield j
    def __reversed__(self) -> Iterator[T]:
        # Descending order.
        for i in reversed(self.a):
            for j in reversed(i): yield j
    def __eq__(self, other) -> bool:
        return list(self) == list(other)
    def __len__(self) -> int:
        return self.size
    def __repr__(self) -> str:
        return "SortedMultiset" + str(self.a)
    def __str__(self) -> str:
        s = str(list(self))
        return "{" + s[1 : len(s) - 1] + "}"
    def _position(self, x: T) -> Tuple[List[T], int, int]:
        "return the bucket, index of the bucket and position in which x should be. self must not be empty."
        # If x exceeds every element, the loop ends on the last bucket and the
        # returned position is len(bucket), i.e. append at the very end.
        for i, a in enumerate(self.a):
            if x <= a[-1]: break
        return (a, i, bisect_left(a, x))
    def __contains__(self, x: T) -> bool:
        if self.size == 0: return False
        a, _, i = self._position(x)
        return i != len(a) and a[i] == x
    def count(self, x: T) -> int:
        "Count the number of x."
        return self.index_right(x) - self.index(x)
    def add(self, x: T) -> None:
        "Add an element. / O(√N)"
        if self.size == 0:
            self.a = [[x]]
            self.size = 1
            return
        a, b, i = self._position(x)
        a.insert(i, x)
        self.size += 1
        # Keep buckets small: split the grown bucket in half in place.
        if len(a) > len(self.a) * self.SPLIT_RATIO:
            mid = len(a) >> 1
            self.a[b:b+1] = [a[:mid], a[mid:]]
    def _pop(self, a: List[T], b: int, i: int) -> T:
        # Remove a[i] from bucket a (= self.a[b]); drop the bucket if it becomes empty.
        ans = a.pop(i)
        self.size -= 1
        if not a: del self.a[b]
        return ans
    def discard(self, x: T) -> bool:
        "Remove an element and return True if removed. / O(√N)"
        # Removes a single occurrence of x.
        if self.size == 0: return False
        a, b, i = self._position(x)
        if i == len(a) or a[i] != x: return False
        self._pop(a, b, i)
        return True
    # lt / le / gt / ge fall off the end and return None when nothing qualifies.
    def lt(self, x: T) -> Optional[T]:
        "Find the largest element < x, or None if it doesn't exist."
        for a in reversed(self.a):
            if a[0] < x:
                return a[bisect_left(a, x) - 1]
    def le(self, x: T) -> Optional[T]:
        "Find the largest element <= x, or None if it doesn't exist."
        for a in reversed(self.a):
            if a[0] <= x:
                return a[bisect_right(a, x) - 1]
    def gt(self, x: T) -> Optional[T]:
        "Find the smallest element > x, or None if it doesn't exist."
        for a in self.a:
            if a[-1] > x:
                return a[bisect_right(a, x)]
    def ge(self, x: T) -> Optional[T]:
        "Find the smallest element >= x, or None if it doesn't exist."
        for a in self.a:
            if a[-1] >= x:
                return a[bisect_left(a, x)]
    def __getitem__(self, i: int) -> T:
        "Return the i-th element."
        # Negative i counts from the end; out-of-range raises IndexError.
        if i < 0:
            for a in reversed(self.a):
                i += len(a)
                if i >= 0: return a[i]
        else:
            for a in self.a:
                if i < len(a): return a[i]
                i -= len(a)
        raise IndexError
    def pop(self, i: int = -1) -> T:
        "Pop and return the i-th element."
        if i < 0:
            for b, a in enumerate(reversed(self.a)):
                i += len(a)
                # ~b == -(b + 1): the bucket's index counted from the end.
                if i >= 0: return self._pop(a, ~b, i)
        else:
            for b, a in enumerate(self.a):
                if i < len(a): return self._pop(a, b, i)
                i -= len(a)
        raise IndexError
    def index(self, x: T) -> int:
        "Count the number of elements < x."
        ans = 0
        for a in self.a:
            if a[-1] >= x:
                return ans + bisect_left(a, x)
            ans += len(a)
        return ans
    def index_right(self, x: T) -> int:
        "Count the number of elements <= x."
        ans = 0
        for a in self.a:
            if a[-1] > x:
                return ans + bisect_right(a, x)
            ans += len(a)
        return ans
Segment Tree
func_list
code〇
# Typical SegTree instantiations as (op, identity element); pick the one you need.
L = [0]*N
seg = SegTree(L, lambda x,y:min(x,y), float('inf'))   # range min
seg = SegTree(L, lambda x,y:max(x,y), -float('inf'))  # range max
seg = SegTree(L, lambda x,y:x+y, 0)                   # range sum
seg = SegTree(L, lambda x,y:x*y, 1)                   # range product
# range gcd: the op must *call* gcd (the old lambda returned the function object).
seg = SegTree(L, lambda x,y:math.gcd(x,y), 0)
Segment Tree
一点更新、一点加算、区間和 他いろいろ
code〇
def segfunc(x, y):
    """Default SegTree operation: addition (identity element 0)."""
    total = x + y
    return total
class SegTree:
    """Segment tree with point updates and range folds over any monoid.

    init_val: initial array, segfunc: associative binary operation,
    ide_ele: its identity element. Operands are always combined
    left-to-right, so non-commutative operations (string or matrix
    concatenation, ...) give correct results too.
    """
    def __init__(self, init_val, segfunc, ide_ele):
        n = len(init_val)
        self.segfunc = segfunc
        self.ide_ele = ide_ele
        # Leaves live at tree[num:num+n]; node i has children 2i and 2i+1.
        self.num = 1 << (n - 1).bit_length()
        self.tree = [ide_ele] * 2 * self.num
        for i in range(n):
            self.tree[self.num + i] = init_val[i]
        for i in range(self.num - 1, 0, -1):
            self.tree[i] = self.segfunc(self.tree[2 * i], self.tree[2 * i + 1])

    def _pull_up(self, k):
        """Recompute every ancestor of tree node k from its two children."""
        k >>= 1
        while k:
            self.tree[k] = self.segfunc(self.tree[2 * k], self.tree[2 * k + 1])
            k >>= 1

    def add(self, k, x):
        """a[k] += x."""
        k += self.num
        self.tree[k] += x
        self._pull_up(k)

    def update(self, k, x):
        """a[k] = x."""
        k += self.num
        self.tree[k] = x
        self._pull_up(k)

    def query(self, l, r):
        """Fold segfunc over a[l:r] (half-open); ide_ele when the range is empty."""
        left = self.ide_ele   # fold of the nodes taken from the left boundary
        right = self.ide_ele  # fold of the nodes taken from the right boundary
        l += self.num
        r += self.num
        while l < r:
            if l & 1:
                left = self.segfunc(left, self.tree[l])
                l += 1
            if r & 1:
                r -= 1
                right = self.segfunc(self.tree[r], right)
            l >>= 1
            r >>= 1
        return self.segfunc(left, right)
Fenwick Tree
一点加算,区間和
ABC276
code〇
class Fenwick_Tree:
    """Binary indexed tree (ACL style): point add and half-open range sum, 0-indexed."""

    def __init__(self, n):
        self._n = n
        self.data = [0] * n

    def add(self, p, x):
        """a[p] += x."""
        assert 0 <= p < self._n
        data, size = self.data, self._n
        pos = p + 1  # 1-based position inside the tree
        while pos <= size:
            data[pos - 1] += x
            pos += pos & -pos

    def sum(self, l, r):
        """Return a[l] + ... + a[r-1]."""
        assert 0 <= l <= r <= self._n
        return self._sum(r) - self._sum(l)

    def _sum(self, r):
        """Prefix sum a[0] + ... + a[r-1]."""
        total = 0
        pos = r
        while pos > 0:
            total += self.data[pos - 1]
            pos -= pos & -pos
        return total
Lazy Segment Tree(区間加算・区間min/max取得。遅延値を区間長倍しないので区間和には使えない)
code〇
def segfunc(x, y):
    """Addition op for the lazy segment tree (identity element 0).

    NOTE(review): LazySegTree_RAQ applies a pending add as tree[node] += x,
    without multiplying by the segment length. It is therefore only correct
    for ops like min/max, and range sums computed with this x + y op will be
    wrong.
    """
    total = x + y
    return total
class LazySegTree_RAQ:
    """Non-recursive lazy segment tree: range add + range fold.

    A pending add v is applied to a node as tree[node] += v, not multiplied
    by the node's segment length. add() also rebuilds ancestors as
    segfunc(left, right) + lazy[node]. Both are only correct when
    segfunc(a + v, b + v) == segfunc(a, b) + v, i.e. for min / max;
    with segfunc = x + y, range sums come out wrong.
    """
    def __init__(self,init_val,segfunc,ide_ele):
        n = len(init_val)
        self.segfunc = segfunc
        self.ide_ele = ide_ele
        # Number of leaves: a power of two >= n; node 1 is the root.
        self.num = 1<<(n-1).bit_length()
        self.tree = [ide_ele]*2*self.num
        # lazy[i]: amount already added to tree[i] but not yet pushed to its children.
        self.lazy = [0]*2*self.num
        for i in range(n):
            self.tree[self.num+i] = init_val[i]
        for i in range(self.num-1,0,-1):
            self.tree[i] = self.segfunc(self.tree[2*i], self.tree[2*i+1])
    def gindex(self,l,r):
        # Generator over internal nodes on the paths from the boundary leaves of
        # [l, r) up to the root, yielded bottom-up. lm / rm appear to skip the
        # lowest ancestors whose subtree is aligned with the boundary.
        # NOTE(review): that reading of lm / rm is inferred, not verified here.
        l += self.num
        r += self.num
        lm = l>>(l&-l).bit_length()
        rm = r>>(r&-r).bit_length()
        while r>l:
            if l<=lm:
                yield l
            if r<=rm:
                yield r
            r >>= 1
            l >>= 1
        while l:
            yield l
            l >>= 1
    def propagates(self,*ids):
        # ids arrive bottom-up from gindex; push the lazies down top-down.
        for i in reversed(ids):
            v = self.lazy[i]
            if v==0:
                continue
            self.lazy[i] = 0
            self.lazy[2*i] += v
            self.lazy[2*i+1] += v
            self.tree[2*i] += v
            self.tree[2*i+1] += v
    def add(self,l,r,x):
        # a[i] += x for l <= i < r.
        # ids is a lazy generator. gindex binds its own l, r at call time, so the
        # shifts below do not affect it; it is consumed after the loop.
        ids = self.gindex(l,r)
        l += self.num
        r += self.num
        while l<r:
            if l&1:
                self.lazy[l] += x
                self.tree[l] += x
                l += 1
            if r&1:
                self.lazy[r-1] += x
                self.tree[r-1] += x
            r >>= 1
            l >>= 1
        # Rebuild the affected ancestors bottom-up, re-adding their own pending lazy.
        for i in ids:
            self.tree[i] = self.segfunc(self.tree[2*i], self.tree[2*i+1]) + self.lazy[i]
    def query(self,l,r):
        # Fold segfunc over a[l:r] (half-open); ide_ele for an empty range.
        self.propagates(*self.gindex(l,r))
        res = self.ide_ele
        l += self.num
        r += self.num
        while l<r:
            if l&1:
                res = self.segfunc(res,self.tree[l])
                l += 1
            if r&1:
                res = self.segfunc(res,self.tree[r-1])
            l >>= 1
            r >>= 1
        return res
Lazy Segment Tree(区間更新)
code〇
def segfunc(x, y):
    """Min op for the range-update lazy segment tree (identity: +inf).

    Mirrors builtin min(x, y): returns y only when y < x, otherwise x.
    """
    return y if y < x else x
class LazySegTree_RUQ:
    """Non-recursive lazy segment tree: range assignment + range fold.

    A pending assignment v overwrites a node as tree[node] = v, without
    taking the segment length into account. It is therefore meant for ops
    like min / max where the fold of equal values is that value.
    """
    def __init__(self, init_val, segfunc, ide_ele):
        n = len(init_val)
        self.segfunc = segfunc
        self.ide_ele = ide_ele
        # Number of leaves: a power of two >= n; node 1 is the root.
        self.num = 1 << (n - 1).bit_length()
        self.tree = [ide_ele] * 2 * self.num
        # lazy[i]: value assigned to i's whole subtree but not yet pushed down (None: nothing pending).
        self.lazy = [None] * 2 * self.num
        for i in range(n):
            self.tree[self.num + i] = init_val[i]
        for i in range(self.num - 1, 0, -1):
            self.tree[i] = self.segfunc(self.tree[2 * i], self.tree[2 * i + 1])

    def gindex(self, l, r):
        """Yield, bottom-up, the internal nodes above the boundaries of [l, r)."""
        l += self.num
        r += self.num
        lm = l >> (l & -l).bit_length()
        rm = r >> (r & -r).bit_length()
        while r > l:
            if l <= lm:
                yield l
            if r <= rm:
                yield r
            r >>= 1
            l >>= 1
        while l:
            yield l
            l >>= 1

    def propagates(self, *ids):
        """Push pending assignments down from the given nodes, top-down."""
        for i in reversed(ids):
            v = self.lazy[i]
            if v is None:
                continue
            self.lazy[i] = None
            self.lazy[2 * i] = v
            self.lazy[2 * i + 1] = v
            self.tree[2 * i] = v
            self.tree[2 * i + 1] = v

    def update(self, l, r, x):
        """a[i] = x for l <= i < r."""
        # Materialize once: used top-down for the push, then bottom-up for the rebuild.
        ids = tuple(self.gindex(l, r))
        self.propagates(*ids)
        l += self.num
        r += self.num
        while l < r:
            if l & 1:
                self.lazy[l] = x
                self.tree[l] = x
                l += 1
            if r & 1:
                self.lazy[r - 1] = x
                self.tree[r - 1] = x
            r >>= 1
            l >>= 1
        for i in ids:
            self.tree[i] = self.segfunc(self.tree[2 * i], self.tree[2 * i + 1])

    def query(self, l, r):
        """Fold segfunc over a[l:r] (half-open); ide_ele for an empty range."""
        self.propagates(*self.gindex(l, r))
        res = self.ide_ele
        l += self.num
        r += self.num
        while l < r:
            if l & 1:
                res = self.segfunc(res, self.tree[l])
                l += 1
            if r & 1:
                res = self.segfunc(res, self.tree[r - 1])
            l >>= 1
            r >>= 1
        return res
LCA
LCA
最小共通祖先とその距離を求める
ABC014
code〇
from collections import deque
import sys
sys.setrecursionlimit(10 ** 8)
class LCA:
    """Lowest common ancestor on a tree via binary lifting (doubling).

    G: adjacency list of an undirected tree; root: root vertex (default 0).
    dist[v] is the depth of v (edges from root; -1 if unreachable).
    parent[k][v] is the 2**k-th ancestor of v, or -1 if there is none.
    """
    def __init__(self, G, root=0):
        V = len(G)
        # K levels with 2**K >= V, enough to jump any depth difference.
        K = 1
        while 1 << K < V:
            K += 1
        self.parent = [[-1] * V for _ in range(K)]
        self.dist = [-1] * V
        self.bfs(G, root)
        for i in range(K - 1):
            prev, nxt = self.parent[i], self.parent[i + 1]
            for j in range(V):
                if prev[j] != -1:
                    nxt[j] = prev[prev[j]]

    def bfs(self, G, root):
        """Fill dist (depths) and parent[0] (direct parents) by BFS from root."""
        self.dist[root] = 0
        que = deque([root])
        while que:
            n = que.popleft()
            for v in G[n]:
                if self.dist[v] == -1:
                    self.dist[v] = self.dist[n] + 1
                    self.parent[0][v] = n
                    que.append(v)

    def query(self, a, b):
        """Return the lowest common ancestor of a and b."""
        if self.dist[a] < self.dist[b]:
            a, b = b, a
        K = len(self.parent)
        # Lift the deeper vertex to the depth of the shallower one.
        gap = self.dist[a] - self.dist[b]
        for i in range(K):
            if gap >> i & 1:
                a = self.parent[i][a]
        if a == b:
            return a
        # Lift both while their ancestors differ; they end just below the LCA.
        for i in reversed(range(K)):
            if self.parent[i][a] != self.parent[i][b]:
                a = self.parent[i][a]
                b = self.parent[i][b]
        return self.parent[0][a]

    def distance(self, a, b):
        """Number of edges on the tree path between a and b."""
        return self.dist[a] + self.dist[b] - 2 * self.dist[self.query(a, b)]
パスの最大重みを求める
ダブリングと呼ぶらしい
めちゃマイナー
ABC235
code〇
class SecondMinimumSpanningTreeLight:
    """Maximum edge weight on tree paths via binary lifting (doubling).

    graph: list of dicts, graph[u][v] = weight of tree edge u-v (stored both
    ways). The tree is rooted at vertex 0 and assumed connected. Typical use
    is MST queries: for a non-tree edge (u, v, w), compare w with
    get_maximum_weight(u, v). Weights are assumed >= -1, because -1 is the
    "no edge" sentinel.
    """
    def __init__(self, graph):
        from collections import deque
        self.n = len(graph)
        self.parent = [-1] * self.n
        self.depth = [-1] * self.n
        queue = deque([0])
        self.depth[0] = 0
        while queue:
            i = queue.popleft()
            for j in graph[i]:
                if self.depth[j] == -1:
                    self.depth[j] = self.depth[i] + 1
                    self.parent[j] = i
                    queue.append(j)
        # Need 2**cols > max depth (<= n-1). (n-1).bit_length() equals
        # ceil(log2(n)) for n >= 1 but avoids floating-point rounding.
        self.cols = max(2, (self.n - 1).bit_length())
        # Flattened tables indexed [v * cols + k]:
        #   dp             -> 2**k-th ancestor of v (-1 if above the root)
        #   maximum_weight -> heaviest edge on those 2**k upward steps
        #                     (up to the root if the jump overshoots)
        self.dp = [-1] * self.cols * self.n
        self.maximum_weight = [-1] * self.cols * self.n
        for i in range(self.n):
            self.dp[i * self.cols] = self.parent[i]
            if self.parent[i] != -1:
                self.maximum_weight[i * self.cols] = graph[self.parent[i]][i]
        for j in range(1, self.cols):
            for i in range(self.n):
                ancestor = self.dp[i * self.cols + j - 1]
                max_weight = self.maximum_weight[i * self.cols + j - 1]
                self.maximum_weight[i * self.cols + j] = max_weight
                if ancestor != -1:
                    self.dp[i * self.cols + j] = self.dp[ancestor * self.cols + j - 1]
                    self.maximum_weight[i * self.cols + j] = max(
                        max_weight, self.maximum_weight[ancestor * self.cols + j - 1])

    def get_maximum_weight(self, x, y):
        """Return the heaviest edge weight on the tree path x..y (-1 if x == y)."""
        cols = self.cols
        if self.depth[x] < self.depth[y]:
            x, y = y, x
        best = -1
        # Lift x to y's depth, one set bit of the depth difference at a time.
        diff = self.depth[x] - self.depth[y]
        k = 0
        while diff:
            if diff & 1:
                best = max(best, self.maximum_weight[x * cols + k])
                x = self.dp[x * cols + k]
            diff >>= 1
            k += 1
        if x == y:
            return best
        # Lift both while their ancestors differ; they end as children of the LCA.
        for k in reversed(range(cols)):
            ax = self.dp[x * cols + k]
            ay = self.dp[y * cols + k]
            if ax != ay:
                best = max(best, self.maximum_weight[x * cols + k],
                           self.maximum_weight[y * cols + k])
                x, y = ax, ay
        # Finally include the two edges into the LCA.
        return max(best, self.maximum_weight[x * cols], self.maximum_weight[y * cols])
# Example driver: n (vertex count) and edges (1-indexed (a, b, c) triples of
# the spanning tree) must be defined by the caller.
# graph[u][v] = weight, stored in both directions.
graph = [{} for _ in range(n)]
for a, b, c in edges:
    a, b = a - 1, b - 1
    graph[a][b] = graph[b][a] = c
mst = SecondMinimumSpanningTreeLight(graph)
二分探索
bisect
code〇
import bisect
def index(a, x):
    """Return the index of the leftmost element equal to x in sorted a, or -1."""
    pos = bisect.bisect_left(a, x)
    found = pos < len(a) and a[pos] == x
    return pos if found else -1
def find_lt(a, x, default=-1):
    """Return the largest value in sorted a that is < x.

    Returns default (-1 unless given) when no such value exists. Pass e.g.
    default=None when -1 can itself be a stored value.
    """
    i = bisect.bisect_left(a, x)
    return a[i - 1] if i else default
def find_le(a, x, default=-1):
    """Return the largest value in sorted a that is <= x.

    Returns default (-1 unless given) when no such value exists. Pass e.g.
    default=None when -1 can itself be a stored value.
    """
    i = bisect.bisect_right(a, x)
    return a[i - 1] if i else default
def find_gt(a, x, default=-1):
    """Return the smallest value in sorted a that is > x.

    Returns default (-1 unless given) when no such value exists. Pass e.g.
    default=None when -1 can itself be a stored value.
    """
    i = bisect.bisect_right(a, x)
    return a[i] if i != len(a) else default
def find_ge(a, x, default=-1):
    """Return the smallest value in sorted a that is >= x.

    Returns default (-1 unless given) when no such value exists. Pass e.g.
    default=None when -1 can itself be a stored value.
    """
    i = bisect.bisect_left(a, x)
    return a[i] if i != len(a) else default
めぐる式二分探索
code〇
def is_ok(arg):
    """Predicate for meguru_bisect: True when arg satisfies the condition.

    Placeholder to be defined per problem; as written it returns None (falsy).
    """
    return None
def meguru_bisect(ng, ok, pred=None):
    '''
    Meguru-style binary search on the boundary between ng and ok.

    ng: a value known to fail (e.g. smallest candidate - 1),
    ok: a value known to pass (e.g. largest candidate + 1).
    Returns the ok closest to ng, i.e. the minimum passing value when
    ng < ok, or the maximum passing value when ng > ok.
    pred: predicate to test; defaults to the global is_ok for compatibility.
    '''
    if pred is None:
        pred = is_ok
    while abs(ok - ng) > 1:
        mid = (ok + ng) // 2
        if pred(mid):
            ok = mid
        else:
            ng = mid
    return ok
dp
ナップザック
yaketake08’s 実装メモ(ナップサック問題 (重みが小さいケース))
yaketake08’s 実装メモ(ナップサック問題 (価値が小さいケース))
- $N$ 種類の品物がある
- $i$ 番目の品物の価値は $vi$, 容量は $wi$, 個数は $ci$
- 重さの総和 $W$ まで入るナップサックに入れる
- ナップサックに入る品物の価値を最大化する
0-1ナップザック:重みが小さいケース
$c_i = 1$
計算量は$O(NW)$
code〇
# Item i has weight ws[i] and value vs[i].
def solve(N, W, ws, vs):
    """0-1 knapsack with small capacity W: best total value within weight W. O(NW).

    best[c] = best value using total weight <= c; capacities are scanned
    downwards so that each item is used at most once.
    """
    best = [0] * (W + 1)
    for k in range(N):
        value, weight = vs[k], ws[k]
        for cap in reversed(range(weight, W + 1)):
            cand = best[cap - weight] + value
            if cand > best[cap]:
                best[cap] = cand
    return max(best)
個数制限なしナップザック:重みが小さいケース
$c_i$は無制限
計算量は$O(NW)$
code〇
# Item i has weight ws[i] and value vs[i].
def solve(N, W, ws, vs):
    """Unbounded knapsack with small capacity W (unlimited copies per item). O(NW).

    best[c] = best value using total weight <= c; capacities are scanned
    upwards so that an item can be taken repeatedly.
    """
    best = [0] * (W + 1)
    for k in range(N):
        value, weight = vs[k], ws[k]
        for cap in range(weight, W + 1):
            cand = best[cap - weight] + value
            if cand > best[cap]:
                best[cap] = cand
    return max(best)
0-1ナップザック:価値が小さいケース
計算量は$O(N^2\max_i{v_i})$
code〇
# Item i has weight ws[i] and value vs[i].
def solve(N, W, ws, vs):
    """0-1 knapsack with small values: O(N * sum(vs)).

    min_w[t] = minimum weight reaching total value exactly t. W + 1 marks
    "unreachable", since it exceeds the capacity. The answer is the largest
    t whose minimum weight fits in W.
    """
    V = sum(vs)
    min_w = [W + 1] * (V + 1)
    min_w[0] = 0
    for k in range(N):
        value, weight = vs[k], ws[k]
        for tot in reversed(range(value, V + 1)):
            cand = min_w[tot - value] + weight
            if cand < min_w[tot]:
                min_w[tot] = cand
    return max(t for t in range(V + 1) if min_w[t] <= W)
巡回セールスマン問題
code〇
def tsp(dist, N, inf=None):
    """Bit-DP for travelling-salesman style problems.

    Vertices 0 .. N-3 are towns, N-2 is the start and N-1 the goal.
    dist[i][j] is the cost of moving from i to j.
    Returns dp, where dp[S][i] is the minimum cost of leaving the start,
    visiting exactly the towns in bitmask S and standing on town i (i in S).
    The goal is not added here; finish with e.g.
    min(dp[(1 << (N - 2)) - 1][i] + dist[i][N - 1] for i in range(N - 2)).

    inf: value for unreachable states; defaults to the module-level INF.
    Time O(N^2 * 2^N), memory O(N * 2^N).
    """
    if inf is None:
        inf = INF
    n = N - 2
    dp = [[inf] * n for _ in range(1 << n)]
    for i in range(n):
        dp[1 << i][i] = dist[-2][i]  # first move: start -> town i
    for mask in range(1 << n):
        row = dp[mask]
        for u in range(n):
            if not mask >> u & 1:
                continue
            base = row[u]
            for v in range(n):
                if mask >> v & 1:
                    continue
                new_mask = mask | (1 << v)
                cand = base + dist[u][v]
                if cand < dp[new_mask][v]:
                    dp[new_mask][v] = cand
    return dp
その他
マージソート
code×
code
転倒数
code〇
# Version for a 0-indexed sequence whose values lie in [0, N-1].
def inversion(inds,N):
    """Count pairs i < j with inds[i] > inds[j] using a Fenwick tree. O(len * log N).

    inds: ints in [0, N-1]; duplicates are allowed, and equal values are
    not counted as inversions.
    """
    bit = [0] * (N + 1)
    def bit_add(x, w):
        # Add w at 1-based position x.
        while x <= N:
            bit[x] += w
            x += x & -x
    def bit_sum(x):
        # Sum over 1-based positions 1..x, i.e. values 0..x-1.
        ret = 0
        while x > 0:
            ret += bit[x]
            x -= x & -x
        return ret
    inv = 0
    for ind in reversed(inds):
        # Already inserted elements lie to the right; count the strictly smaller ones.
        inv += bit_sum(ind)
        bit_add(ind + 1, 1)
    return inv
# Version for arbitrary sequences of comparable, hashable values.
def inversion(in_arr):
    """Count pairs i < j with in_arr[i] > in_arr[j].

    Each value is first replaced by a distinct rank 0..L-1. Ties are broken
    by position, so equal values never form an inversion. The ranks are
    then counted with a Fenwick tree.
    """
    order = sorted(in_arr)
    # value -> its positions in the sorted order, stored reversed so pop() yields the smallest.
    slots = {}
    for pos, val in enumerate(order):
        if val in slots:
            slots[val].append(pos)
        else:
            slots[val] = [pos]
    for positions in slots.values():
        positions.reverse()
    inds = [slots[val].pop() for val in in_arr]

    size = len(inds)
    tree = [0] * (size + 1)
    total = 0
    for ind in reversed(inds):
        q = ind + 1
        while q > 0:
            total += tree[q]
            q -= q & -q
        p = ind + 1
        while p <= size:
            tree[p] += 1
            p += p & -p
    return total
素因数分解
yaketake08’s 実装メモ(素因数分解 (試し割り法))
code〇
# N: the number to factorize.
def factorization(N):
    """Return the prime factors of N in non-decreasing order (trial division).

    Returns [] for N <= 1. O(sqrt(N)).
    """
    factors = []
    rest = N
    d = 2
    while d * d <= rest:
        q, m = divmod(rest, d)
        if m == 0:
            factors.append(d)
            rest = q
        else:
            d += 1
    # Whatever is left above 1 has no divisor <= sqrt(rest), so it is prime.
    if rest > 1:
        factors.append(rest)
    return factors
素数テーブル作成
code×
code
拡張ユークリッド
yaketake08’s 実装メモ(最大公約数 ((拡張)ユークリッドの互除法))
code〇
# Euclidean Algorithm
def gcd(m, n):
    """Greatest common divisor by recursive Euclid; gcd(m, 0) == m."""
    if n == 0:
        return m  # the old version evaluated m % 0 and raised ZeroDivisionError
    r = m % n
    return gcd(n, r) if r else n
# Euclidean Algorithm (non-recursive)
def gcd2(m, n):
    """Greatest common divisor by iterative Euclid; returns m when n == 0."""
    a, b = m, n
    while b != 0:
        a, b = b, a % b
    return a
# Extended Euclidean Algorithm
def extgcd(a, b):
    """Return (d, x, y) with d = gcd(a, b) and a*x + b*y == d.

    Iterative form; it follows the same quotient sequence as the recursive
    version and yields the same coefficients.
    """
    x0, y0, x1, y1 = 1, 0, 0, 1
    while b:
        q = a // b
        a, b = b, a - q * b
        x0, x1 = x1, x0 - q * x1
        y0, y1 = y1, y0 - q * y1
    return a, x0, y0
# lcm (least common multiple)
def lcm(m, n):
    """Least common multiple of m and n; 0 when either argument is 0."""
    if m == 0 or n == 0:
        return 0  # avoid gcd(., 0) / division by gcd(0, 0)
    return m // gcd(m, n) * n
zアルゴリズム
SとS[i:]の最長共通接頭辞の長さのリストを返す
ABC284
z-algorithm
code〇
def z(S):
    """Z-algorithm: A[i] = length of the longest common prefix of S and S[i:].

    A[0] == len(S), as in the reference algorithm. O(len(S)).
    S can be any indexable sequence whose elements support ==.
    """
    N = len(S)
    A = [0] * N
    if N == 0:
        return A
    A[0] = N  # never read below (k starts at 1), so it cannot disturb the scan
    i, j = 1, 0
    while i < N:
        while i + j < N and S[j] == S[i + j]:
            j += 1
        A[i] = j
        if j == 0:
            i += 1
            continue
        # Inside the matched window S[i:i+j], reuse values already computed for the prefix.
        k = 1
        while i + k < N and k + A[k] < j:
            A[i + k] = A[k]
            k += 1
        i += k
        j -= k
    return A
最長共通部分列
yaketake08’s 実装メモ(最長共通部分列)
計算量は$O(NM)$
code
def solve(S, T):
    """Return one longest common subsequence of strings S and T. O(|S||T|).

    suf[i][j] = LCS length of S[i:] and T[j:]; suf[0][0] is the length.
    The subsequence is then rebuilt greedily from the front.
    """
    n, m = len(S), len(T)
    suf = [[0] * (m + 1) for _ in range(n + 1)]
    for i in reversed(range(n)):
        row, below = suf[i], suf[i + 1]
        for j in reversed(range(m)):
            best = max(below[j], row[j + 1])
            if S[i] == T[j]:
                best = max(best, below[j + 1] + 1)
            row[j] = best
    # Reconstruction.
    out = []
    i = j = 0
    while i < n and j < m:
        if S[i] == T[j]:
            out.append(S[i])
            i += 1
            j += 1
        elif suf[i][j] == suf[i + 1][j]:
            i += 1
        elif suf[i][j] == suf[i][j + 1]:
            j += 1
    return "".join(out)
最長回文
code×
code
幾何
yaketake08’s 実装メモ
リンク略
-
点の線対称
-
多角形の面積
-
多角形の点包含判定
-
線分同士の交差判定
-
線分と頂点の最短距離
-
直線同士の交点
-
直線(線分)と円の交点
-
円同士の交点
-
2つの円の共通部分の面積
-
円の共通接線の接点/接線
-
三角形の外接円/内接円/傍接円
-
凸包 (Graham Scan)
-
凸多角形の点包含判定
-
凸多角形同士の交差判定/交点
-
直線による凸多角形の切断
-
凸多角形と直線の交差判定/交点
-
点から凸多角形への接線
-
最遠点対 (キャリパー法)
-
最近点対 (分割統治法)
-
座標圧縮