はじめに
昨年Atcoderを始めた大学院生です.
Pythonのライブラリを自分用にまとめました.
コンテンツは今後増やすつもりです.
読んでくださる方へ
Qiita新参者のライブラリなんて怖くて使えないと思うので
なるべく他の方の記事や自身のACコードを添付するつもりです.
参考にする際は自己責任でお願いします.
ライブラリを使った練習問題はこちらで検索してみてください.
ダイクストラ
始点Sから各頂点への最短距離を求める.
ACコード(ABC340D)
import heapq
def dijkstra(S,N,edge):
# 始点,頂点数,辺集合(edge[u] = [v,w])を入力として
# 各頂点の最小コストのリストを返す
hq = [(0, S)]
heapq.heapify(hq)
cost = [float('inf')] * N
cost[S] = 0
while hq:
c, v = heapq.heappop(hq)
if c > cost[v]:continue
for u, w in edge[v]:
tmp = w + cost[v]
if tmp < cost[u]:
cost[u] = tmp
heapq.heappush(hq, (tmp, u))
return cost
すべての辺の重みが1のときはBFSになる.
グリッドBFS
各成分の頂点数,辺数を求める
二部グラフ
def dfs(x,col):
# 始点xからdfsをして,各色で何回塗ったかをcntで返す
# 二部グラフでないときもcntを返すので注意
colors[x] = col
cnt[(col+1)//2]+=1
for v in edge[x]:
if colors[v]==col: return False
if colors[v]==0 and not dfs(v,-col): return False
return True
強連結成分分解
トポロジカルソート
サイクル検出
クラスカル法
プリム法
Fenwick
マージソート
bisect
import bisect
def index(a, x): # 探索したい数値のindexを探索
'Locate the leftmost value exactly equal to x'
i = bisect.bisect_left(a, x)
if i != len(a) and a[i] == x:
return i
else:return -1
def find_lt(a, x): # 探索したい数値未満のうち最大の数値を探索
'Find rightmost value less than x'
i = bisect.bisect_left(a, x)
if i:
return a[i-1]
else:return -1
def find_le(a, x): # 探索したい数値以下のうち最大の数値を探索
'Find rightmost value less than or equal to x'
i = bisect.bisect_right(a, x)
if i:
return a[i-1]
else:return -1
def find_gt(a, x): # 探索したい数値を超えるもののうち最小の数値を探索
'Find leftmost value greater than x'
i = bisect.bisect_right(a, x)
if i != len(a):
return a[i]
else:return -1
def find_ge(a, x): # 探索したい数値以上のうち最小の数値を探索
'Find leftmost item greater than or equal to x'
i = bisect.bisect_left(a, x)
if i != len(a):
return a[i]
else:return -1
めぐる式二分探索
def is_ok(arg):
# 条件を満たすかどうか?問題ごとに定義
pass
def meguru_bisect(ng, ok):
'''
初期値のng,okを受け取り,is_okを満たす最小(最大)のokを返す
まずis_okを定義すべし
ng ok は とり得る最小の値-1 とり得る最大の値+1
最大最小が逆の場合はよしなにひっくり返す
'''
while (abs(ok - ng) > 1):
mid = (ok + ng) // 2
if is_ok(mid):
ok = mid
else:
ng = mid
return ok
セグ木
参考記事を見てください
例) セグ木(区間最小)
#####segfunc#####
def segfunc(x, y):
return min(x, y)
#################
#####ide_ele#####
ide_ele = float('inf')
#################
class SegTree:
"""
init(init_val, ide_ele): 配列init_valで初期化 O(N)
update(k, x): k番目の値をxに更新 O(N)
query(l, r): 区間[l, r)をsegfuncしたものを返す O(logN)
"""
def __init__(self, init_val, segfunc, ide_ele):
"""
init_val: 配列の初期値
segfunc: 区間にしたい操作
ide_ele: 単位元
n: 要素数
num: n以上の最小の2のべき乗
tree: セグメント木(1-index)
"""
n = len(init_val)
self.segfunc = segfunc
self.ide_ele = ide_ele
self.num = 1 << (n - 1).bit_length()
self.tree = [ide_ele] * 2 * self.num
# 配列の値を葉にセット
for i in range(n):
self.tree[self.num + i] = init_val[i]
# 構築していく
for i in range(self.num - 1, 0, -1):
self.tree[i] = self.segfunc(self.tree[2 * i], self.tree[2 * i + 1])
def update(self, k, x):
"""
k番目の値をxに更新
k: index(0-index)
x: update value
"""
k += self.num
self.tree[k] = x
while k > 1:
self.tree[k >> 1] = self.segfunc(self.tree[k], self.tree[k ^ 1])
k >>= 1
def query(self, l, r):
"""
[l, r)のsegfuncしたものを得る
l: index(0-index)
r: index(0-index)
"""
res = self.ide_ele
l += self.num
r += self.num
while l < r:
if l & 1:
res = self.segfunc(res, self.tree[l])
l += 1
if r & 1:
res = self.segfunc(res, self.tree[r - 1])
l >>= 1
r >>= 1
return res
a = [14, 5, 9, 13, 7, 12, 11, 1, 7, 8]
seg = SegTree(a, segfunc, ide_ele)
print(seg.query(0, 8))
seg.update(5, 0)
print(seg.query(0, 8))
デバッグ用コード
def check(r,l):
print([seg.query(i,i+1) for i in range(r,l)])
遅延セグ木
参考記事を見てください
下の例のクエリは区間和を求めます.
例) 遅延セグ木(区間加算)
def segfunc(x,y):
return x+y
class LazySegTree_RAQ:
def __init__(self,init_val,segfunc,ide_ele):
n = len(init_val)
self.segfunc = segfunc
self.ide_ele = ide_ele
self.num = 1<<(n-1).bit_length()
self.tree = [ide_ele]*2*self.num
self.lazy = [0]*2*self.num
for i in range(n):
self.tree[self.num+i] = init_val[i]
for i in range(self.num-1,0,-1):
self.tree[i] = self.segfunc(self.tree[2*i], self.tree[2*i+1])
def gindex(self,l,r):
l += self.num
r += self.num
lm = l>>(l&-l).bit_length()
rm = r>>(r&-r).bit_length()
while r>l:
if l<=lm:
yield l
if r<=rm:
yield r
r >>= 1
l >>= 1
while l:
yield l
l >>= 1
def propagates(self,*ids):
for i in reversed(ids):
v = self.lazy[i]
if v==0:
continue
self.lazy[i] = 0
self.lazy[2*i] += v
self.lazy[2*i+1] += v
self.tree[2*i] += v
self.tree[2*i+1] += v
def add(self,l,r,x):
ids = self.gindex(l,r)
l += self.num
r += self.num
while l<r:
if l&1:
self.lazy[l] += x
self.tree[l] += x
l += 1
if r&1:
self.lazy[r-1] += x
self.tree[r-1] += x
r >>= 1
l >>= 1
for i in ids:
self.tree[i] = self.segfunc(self.tree[2*i], self.tree[2*i+1]) + self.lazy[i]
def query(self,l,r):
self.propagates(*self.gindex(l,r))
res = self.ide_ele
l += self.num
r += self.num
while l<r:
if l&1:
res = self.segfunc(res,self.tree[l])
l += 1
if r&1:
res = self.segfunc(res,self.tree[r-1])
l >>= 1
r >>= 1
return res
例) 遅延セグ木(区間更新)
def segfunc(x,y):
return min(x,y)
class LazySegTree_RUQ:
def __init__(self,init_val,segfunc,ide_ele):
n = len(init_val)
self.segfunc = segfunc
self.ide_ele = ide_ele
self.num = 1<<(n-1).bit_length()
self.tree = [ide_ele]*2*self.num
self.lazy = [None]*2*self.num
for i in range(n):
self.tree[self.num+i] = init_val[i]
for i in range(self.num-1,0,-1):
self.tree[i] = self.segfunc(self.tree[2*i],self.tree[2*i+1])
def gindex(self,l,r):
l += self.num
r += self.num
lm = l>>(l&-l).bit_length()
rm = r>>(r&-r).bit_length()
while r>l:
if l<=lm:
yield l
if r<=rm:
yield r
r >>= 1
l >>= 1
while l:
yield l
l >>= 1
def propagates(self,*ids):
for i in reversed(ids):
v = self.lazy[i]
if v is None:
continue
self.lazy[i] = None
self.lazy[2*i] = v
self.lazy[2*i+1] = v
self.tree[2*i] = v
self.tree[2*i+1] = v
def update(self,l,r,x):
ids = self.gindex(l,r)
self.propagates(*self.gindex(l,r))
l += self.num
r += self.num
while l<r:
if l&1:
self.lazy[l] = x
self.tree[l] = x
l += 1
if r&1:
self.lazy[r-1] = x
self.tree[r-1] = x
r >>= 1
l >>= 1
for i in ids:
self.tree[i] = self.segfunc(self.tree[2*i], self.tree[2*i+1])
def query(self,l,r):
ids = self.gindex(l,r)
self.propagates(*self.gindex(l,r))
res = self.ide_ele
l += self.num
r += self.num
while l<r:
if l&1:
res = self.segfunc(res,self.tree[l])
l += 1
if r&1:
res = self.segfunc(res,self.tree[r-1])
l >>= 1
r >>= 1
return res
出題:ABC340E
UnionFind
PythonでのUnion-Find(素集合データ構造)の実装と使い方
UnionFind
from collections import defaultdict
class UnionFind():
def __init__(self, n):
self.n = n
self.parents = [-1] * n
def find(self, x):
if self.parents[x] < 0:
return x
else:
self.parents[x] = self.find(self.parents[x])
return self.parents[x]
def union(self, x, y):
x = self.find(x)
y = self.find(y)
if x == y:
return
if self.parents[x] > self.parents[y]:
x, y = y, x
self.parents[x] += self.parents[y]
self.parents[y] = x
def size(self, x):
return -self.parents[self.find(x)]
def same(self, x, y):
return self.find(x) == self.find(y)
def members(self, x):
root = self.find(x)
return [i for i in range(self.n) if self.find(i) == root]
def roots(self):
return [i for i, x in enumerate(self.parents) if x < 0]
def group_count(self):
return len(self.roots())
def all_group_members(self):
group_members = defaultdict(list)
for member in range(self.n):
group_members[self.find(member)].append(member)
return group_members
def __str__(self):
return '\n'.join(f'{r}: {m}' for r, m in self.all_group_members().items())
UnionFindLabel
class UnionFindLabel(UnionFind):
def __init__(self, labels):
assert len(labels) == len(set(labels))
self.n = len(labels)
self.parents = [-1] * self.n
self.d = {x: i for i, x in enumerate(labels)}
self.d_inv = {i: x for i, x in enumerate(labels)}
def find_label(self, x):
return self.d_inv[super().find(self.d[x])]
def union(self, x, y):
super().union(self.d[x], self.d[y])
def size(self, x):
return super().size(self.d[x])
def same(self, x, y):
return super().same(self.d[x], self.d[y])
def members(self, x):
root = self.find(self.d[x])
return [self.d_inv[i] for i in range(self.n) if self.find(i) == root]
def roots(self):
return [self.d_inv[i] for i, x in enumerate(self.parents) if x < 0]
def all_group_members(self):
group_members = defaultdict(list)
for member in range(self.n):
group_members[self.d_inv[self.find(member)]].append(self.d_inv[member])
return group_members
Python:重み付きUnion-Find木について
併合するとき要素間の距離も織り込めるイメージです。
F - Good Set Query
ACコード
UnionFindLabel
class WeightedUnionFind:
def __init__(self, n):
self.par = [i for i in range(n+1)]
self.rank = [0] * (n+1)
# 根への距離を管理
self.weight = [0] * (n+1)
# 検索
def find(self, x):
if self.par[x] == x:
return x
else:
y = self.find(self.par[x])
# 親への重みを追加しながら根まで走査
self.weight[x] += self.weight[self.par[x]]
self.par[x] = y
return y
# 併合
def union(self, x, y, w):
rx = self.find(x)
ry = self.find(y)
# xの木の高さ < yの木の高さ
if self.rank[rx] < self.rank[ry]:
self.par[rx] = ry
self.weight[rx] = w - self.weight[x] + self.weight[y]
# xの木の高さ ≧ yの木の高さ
else:
self.par[ry] = rx
self.weight[ry] = -w - self.weight[y] + self.weight[x]
# 木の高さが同じだった場合の処理
if self.rank[rx] == self.rank[ry]:
self.rank[rx] += 1
# 同じ集合に属するか
def same(self, x, y):
return self.find(x) == self.find(y)
# xからyへのコスト
def diff(self, x, y):
return self.weight[x] - self.weight[y]
ナップサック問題
ナップサック問題はいつもyaketake08’s 実装メモを拝借しています.
W以上で価値を最小化する問題を解く関数. ACコード(ABC317)
def solve(N, W, ws, vs):
# i番目の重み ws[i],価値 vs[i]
# 重み総和が W を超える最小の価値を返す
# O(N*sum(ws))
sumw = sum(ws)
dp = [float('inf')] * (sumw+1)
dp[0] = 0
for i in range(N):
v = vs[i]; w = ws[i]
for j in range(sumw, w-1, -1):
dp[j] = min(dp[j-w] + v, dp[j])
return min(dp[W:])
Sortedset
SortedMultiset
転倒数を求める
ABC232-F
ABC244-D
ABC264-D
ABC332-D
ARC120-C
# 0-indexの数列で使える
def inversion(inds,N):
bit = [0] * (N+1)
def bit_add(x,w):
while x <= N:
bit[x] += w
x += (x & -x)
def bit_sum(x):
ret = 0
while x > 0:
ret += bit[x]
x -= (x & -x)
return ret
inv = 0
for ind in reversed(inds):
inv += bit_sum(ind + 1)
bit_add(ind + 1, 1)
return inv
# 任意の配列で使えるバージョン
def inversion(in_arr):
A,B = in_arr, sorted(in_arr)
arr = []
dic = {}
for i,b in enumerate(B):
if b in dic:
arr[dic[b]].append(i)
else:
dic[b] = len(arr)
arr.append([i])
for i in range(len(arr)):
arr[i].reverse()
inds = []
for a in A:
inds.append(arr[dic[a]].pop())
L = len(inds)
bit = [0] * (L+1)
def bit_add(x,w):
while x <= L:
bit[x] += w
x += (x & -x)
def bit_sum(x):
ret = 0
while x > 0:
ret += bit[x]
x -= (x & -x)
return ret
inv = 0
for ind in reversed(inds):
inv += bit_sum(ind + 1)
bit_add(ind + 1, 1)
return inv
素因数分解
nを素因数分解したリストを返す.
よく使います.計算量はO(√n)
参考記事:8を入れたら[[2,3]]が返ってくる.Countの手間が省けます.
def factorization(n):
arr = []
temp = n
for i in range(2, int(-(-n**0.5//1))+1):
if temp%i==0:
cnt=0
while temp%i==0:
cnt+=1
temp //= i
arr.append([i, cnt])
if temp!=1:
arr.append([temp, 1])
if arr==[]:
arr.append([n, 1])
return arr
1からnまでのxor
参考
https://atcoder.jp/contests/abc121/tasks/abc121_d
def xor_sum(n):
ans = 0
keta = len(bin(n))-2
n += 1
for i in range(keta):
block = 1<<(i+1)
ans += (((n//block) * (block // 2) + max(0,n % block - block // 2)) & 1) << i
return ans
拡張ユークリッド
二次元行列をListのままくるくるする
def reverse(H,W,L):
# 反転
LL = [[0 for _ in range(H)] for _ in range(W)]
for i in range(H):
for j in range(W):
LL[j][i] = L[i][j]
return LL
def rotate(H,W,L):
# 時計回り90度回転
LL = [[0 for _ in range(H)] for _ in range(W)]
for i in range(H):
for j in range(W):
LL[j][-1-i] = L[i][j]
return LL
参考文献
Markdown記法チートシート
Qiita記事作成方法 初心者の備忘録
Qiitaに投稿するときの心構え
競プロpython