1
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Python 自分用のライブラリが欲しい

Last updated at Posted at 2024-02-19

はじめに

昨年Atcoderを始めた大学院生です.
Pythonのライブラリを自分用にまとめました.
コンテンツは今後増やすつもりです.

読んでくださる方へ

Qiita新参者のライブラリなんて怖くて使えないと思うので
なるべく他の方の記事や自身のACコードを添付するつもりです.
参考にする際は自己責任でお願いします.

ライブラリを使った練習問題はこちらで検索してみてください.

ダイクストラ

始点Sから各頂点への最短距離を求める.
ACコード(ABC340D)

import heapq
def dijkstra(S,N,edge):
    # 始点,頂点数,辺集合(edge[u] = [v,w])を入力として
    # 各頂点の最小コストのリストを返す
    hq = [(0, S)]
    heapq.heapify(hq)
    cost = [float('inf')] * N
    cost[S] = 0
    while hq:
        c, v = heapq.heappop(hq)
        if c > cost[v]:continue
        for u, w in edge[v]:
            tmp = w + cost[v]
            if tmp < cost[u]:
                cost[u] = tmp
                heapq.heappush(hq, (tmp, u))
    return cost

すべての辺の重みが1のときはBFSになる.

グリッドBFS

各成分の頂点数,辺数を求める

二部グラフ

def dfs(x,col):
    # 始点xからdfsをして,各色で何回塗ったかをcntで返す
    # 二部グラフでないときもcntを返すので注意
    colors[x] = col
    cnt[(col+1)//2]+=1
    for v in edge[x]:
        if colors[v]==col: return False
        if colors[v]==0 and not dfs(v,-col): return False
    return True

強連結成分分解

トポロジカルソート

サイクル検出

クラスカル法

プリム法

Fenwick

マージソート

bisect

参考記事

import bisect
def index(a, x):   # 探索したい数値のindexを探索
    'Locate the leftmost value exactly equal to x'
    i = bisect.bisect_left(a, x)
    if i != len(a) and a[i] == x:
        return i
    else:return -1
    
def find_lt(a, x):   # 探索したい数値未満のうち最大の数値を探索
    'Find rightmost value less than x'
    i = bisect.bisect_left(a, x)
    if i:
        return a[i-1]
    else:return -1

def find_le(a, x):   # 探索したい数値以下のうち最大の数値を探索
    'Find rightmost value less than or equal to x'
    i = bisect.bisect_right(a, x)
    if i:
        return a[i-1]
    else:return -1

def find_gt(a, x):   # 探索したい数値を超えるもののうち最小の数値を探索
    'Find leftmost value greater than x'
    i = bisect.bisect_right(a, x)
    if i != len(a):
        return a[i]
    else:return -1

def find_ge(a, x):   # 探索したい数値以上のうち最小の数値を探索
    'Find leftmost item greater than or equal to x'
    i = bisect.bisect_left(a, x)
    if i != len(a):
        return a[i]
    else:return -1

めぐる式二分探索

参考記事

def is_ok(arg):
    # 条件を満たすかどうか?問題ごとに定義
    pass
      
def meguru_bisect(ng, ok):
    '''
    初期値のng,okを受け取り,is_okを満たす最小(最大)のokを返す
    まずis_okを定義すべし
    ng ok は  とり得る最小の値-1 とり得る最大の値+1
    最大最小が逆の場合はよしなにひっくり返す
    '''
    while (abs(ok - ng) > 1):
        mid = (ok + ng) // 2
        if is_ok(mid):
            ok = mid
        else:
            ng = mid
    return ok

セグ木

参考記事を見てください

例) セグ木(区間最小)
#####segfunc#####
def segfunc(x, y):
    return min(x, y)
#################

#####ide_ele#####
ide_ele = float('inf')
#################

class SegTree:
    """
    init(init_val, ide_ele): 配列init_valで初期化 O(N)
    update(k, x): k番目の値をxに更新 O(N)
    query(l, r): 区間[l, r)をsegfuncしたものを返す O(logN)
    """
    def __init__(self, init_val, segfunc, ide_ele):
        """
        init_val: 配列の初期値
        segfunc: 区間にしたい操作
        ide_ele: 単位元
        n: 要素数
        num: n以上の最小の2のべき乗
        tree: セグメント木(1-index)
        """
        n = len(init_val)
        self.segfunc = segfunc
        self.ide_ele = ide_ele
        self.num = 1 << (n - 1).bit_length()
        self.tree = [ide_ele] * 2 * self.num
        # 配列の値を葉にセット
        for i in range(n):
            self.tree[self.num + i] = init_val[i]
        # 構築していく
        for i in range(self.num - 1, 0, -1):
            self.tree[i] = self.segfunc(self.tree[2 * i], self.tree[2 * i + 1])

    def update(self, k, x):
        """
        k番目の値をxに更新
        k: index(0-index)
        x: update value
        """
        k += self.num
        self.tree[k] = x
        while k > 1:
            self.tree[k >> 1] = self.segfunc(self.tree[k], self.tree[k ^ 1])
            k >>= 1

    def query(self, l, r):
        """
        [l, r)のsegfuncしたものを得る
        l: index(0-index)
        r: index(0-index)
        """
        res = self.ide_ele

        l += self.num
        r += self.num
        while l < r:
            if l & 1:
                res = self.segfunc(res, self.tree[l])
                l += 1
            if r & 1:
                res = self.segfunc(res, self.tree[r - 1])
            l >>= 1
            r >>= 1
        return res

a = [14, 5, 9, 13, 7, 12, 11, 1, 7, 8]

seg = SegTree(a, segfunc, ide_ele)

print(seg.query(0, 8))
seg.update(5, 0)
print(seg.query(0, 8))

デバッグ用コード

def check(r,l):
    print([seg.query(i,i+1) for i in range(r,l)])

出題:ABC339E, ABC341E

遅延セグ木

参考記事を見てください
下の例のクエリは区間和を求めます.

例) 遅延セグ木(区間加算)
def segfunc(x,y):
    return x+y
class LazySegTree_RAQ:
    def __init__(self,init_val,segfunc,ide_ele):
        n = len(init_val)
        self.segfunc = segfunc
        self.ide_ele = ide_ele
        self.num = 1<<(n-1).bit_length()
        self.tree = [ide_ele]*2*self.num
        self.lazy = [0]*2*self.num
        for i in range(n):
            self.tree[self.num+i] = init_val[i]
        for i in range(self.num-1,0,-1):
            self.tree[i] = self.segfunc(self.tree[2*i], self.tree[2*i+1])
    def gindex(self,l,r):
        l += self.num
        r += self.num
        lm = l>>(l&-l).bit_length()
        rm = r>>(r&-r).bit_length()
        while r>l:
            if l<=lm:
                yield l
            if r<=rm:
                yield r
            r >>= 1
            l >>= 1
        while l:
            yield l
            l >>= 1
    def propagates(self,*ids):
        for i in reversed(ids):
            v = self.lazy[i]
            if v==0:
                continue
            self.lazy[i] = 0
            self.lazy[2*i] += v
            self.lazy[2*i+1] += v
            self.tree[2*i] += v
            self.tree[2*i+1] += v
    def add(self,l,r,x):
        ids = self.gindex(l,r)
        l += self.num
        r += self.num
        while l<r:
            if l&1:
                self.lazy[l] += x
                self.tree[l] += x
                l += 1
            if r&1:
                self.lazy[r-1] += x
                self.tree[r-1] += x
            r >>= 1
            l >>= 1
        for i in ids:
            self.tree[i] = self.segfunc(self.tree[2*i], self.tree[2*i+1]) + self.lazy[i]
    def query(self,l,r):
        self.propagates(*self.gindex(l,r))
        res = self.ide_ele
        l += self.num
        r += self.num
        while l<r:
            if l&1:
                res = self.segfunc(res,self.tree[l])
                l += 1
            if r&1:
                res = self.segfunc(res,self.tree[r-1])
            l >>= 1
            r >>= 1
        return res

例) 遅延セグ木(区間更新)
def segfunc(x,y):
    return min(x,y)
class LazySegTree_RUQ:
    def __init__(self,init_val,segfunc,ide_ele):
        n = len(init_val)
        self.segfunc = segfunc
        self.ide_ele = ide_ele
        self.num = 1<<(n-1).bit_length()
        self.tree = [ide_ele]*2*self.num
        self.lazy = [None]*2*self.num
        for i in range(n):
            self.tree[self.num+i] = init_val[i]
        for i in range(self.num-1,0,-1):
            self.tree[i] = self.segfunc(self.tree[2*i],self.tree[2*i+1])
    def gindex(self,l,r):
        l += self.num
        r += self.num
        lm = l>>(l&-l).bit_length()
        rm = r>>(r&-r).bit_length()
        while r>l:
            if l<=lm:
                yield l
            if r<=rm:
                yield r
            r >>= 1
            l >>= 1
        while l:
            yield l
            l >>= 1
    def propagates(self,*ids):
        for i in reversed(ids):
            v = self.lazy[i]
            if v is None:
                continue
            self.lazy[i] = None
            self.lazy[2*i] = v
            self.lazy[2*i+1] = v
            self.tree[2*i] = v
            self.tree[2*i+1] = v
    def update(self,l,r,x):
        ids = self.gindex(l,r)
        self.propagates(*self.gindex(l,r))
        l += self.num
        r += self.num
        while l<r:
            if l&1:
                self.lazy[l] = x
                self.tree[l] = x
                l += 1
            if r&1:
                self.lazy[r-1] = x
                self.tree[r-1] = x
            r >>= 1
            l >>= 1
        for i in ids:
            self.tree[i] = self.segfunc(self.tree[2*i], self.tree[2*i+1])
    def query(self,l,r):
        ids = self.gindex(l,r)
        self.propagates(*self.gindex(l,r))
        res = self.ide_ele
        l += self.num
        r += self.num
        while l<r:
            if l&1:
                res = self.segfunc(res,self.tree[l])
                l += 1
            if r&1:
                res = self.segfunc(res,self.tree[r-1])
            l >>= 1
            r >>= 1
        return res

出題:ABC340E

UnionFind

PythonでのUnion-Find(素集合データ構造)の実装と使い方

UnionFind

from collections import defaultdict

class UnionFind():
    def __init__(self, n):
        self.n = n
        self.parents = [-1] * n

    def find(self, x):
        if self.parents[x] < 0:
            return x
        else:
            self.parents[x] = self.find(self.parents[x])
            return self.parents[x]

    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)

        if x == y:
            return

        if self.parents[x] > self.parents[y]:
            x, y = y, x

        self.parents[x] += self.parents[y]
        self.parents[y] = x

    def size(self, x):
        return -self.parents[self.find(x)]

    def same(self, x, y):
        return self.find(x) == self.find(y)

    def members(self, x):
        root = self.find(x)
        return [i for i in range(self.n) if self.find(i) == root]

    def roots(self):
        return [i for i, x in enumerate(self.parents) if x < 0]

    def group_count(self):
        return len(self.roots())

    def all_group_members(self):
        group_members = defaultdict(list)
        for member in range(self.n):
            group_members[self.find(member)].append(member)
        return group_members

    def __str__(self):
        return '\n'.join(f'{r}: {m}' for r, m in self.all_group_members().items())

UnionFindLabel

class UnionFindLabel(UnionFind):
    def __init__(self, labels):
        assert len(labels) == len(set(labels))

        self.n = len(labels)
        self.parents = [-1] * self.n
        self.d = {x: i for i, x in enumerate(labels)}
        self.d_inv = {i: x for i, x in enumerate(labels)}

    def find_label(self, x):
        return self.d_inv[super().find(self.d[x])]

    def union(self, x, y):
        super().union(self.d[x], self.d[y])

    def size(self, x):
        return super().size(self.d[x])

    def same(self, x, y):
        return super().same(self.d[x], self.d[y])

    def members(self, x):
        root = self.find(self.d[x])
        return [self.d_inv[i] for i in range(self.n) if self.find(i) == root]

    def roots(self):
        return [self.d_inv[i] for i, x in enumerate(self.parents) if x < 0]

    def all_group_members(self):
        group_members = defaultdict(list)
        for member in range(self.n):
            group_members[self.d_inv[self.find(member)]].append(self.d_inv[member])
        return group_members

Python:重み付きUnion-Find木について
併合するとき要素間の距離も織り込めるイメージです。
F - Good Set Query
ACコード

UnionFindLabel

class WeightedUnionFind:
    def __init__(self, n):
        self.par = [i for i in range(n+1)]
        self.rank = [0] * (n+1)
        # 根への距離を管理
        self.weight = [0] * (n+1)

    # 検索
    def find(self, x):
        if self.par[x] == x:
            return x
        else:
            y = self.find(self.par[x])
            # 親への重みを追加しながら根まで走査
            self.weight[x] += self.weight[self.par[x]]
            self.par[x] = y
            return y

    # 併合
    def union(self, x, y, w):
        rx = self.find(x)
        ry = self.find(y)
        # xの木の高さ < yの木の高さ
        if self.rank[rx] < self.rank[ry]:
            self.par[rx] = ry
            self.weight[rx] = w - self.weight[x] + self.weight[y]
        # xの木の高さ ≧ yの木の高さ
        else:
            self.par[ry] = rx
            self.weight[ry] = -w - self.weight[y] + self.weight[x]
            # 木の高さが同じだった場合の処理
            if self.rank[rx] == self.rank[ry]:
                self.rank[rx] += 1

    # 同じ集合に属するか
    def same(self, x, y):
        return self.find(x) == self.find(y)

    # xからyへのコスト
    def diff(self, x, y):
        return self.weight[x] - self.weight[y]

ナップサック問題

ナップサック問題はいつもyaketake08’s 実装メモを拝借しています.
W以上で価値を最小化する問題を解く関数. ACコード(ABC317)

def solve(N, W, ws, vs):
    # i番目の重み ws[i],価値 vs[i]
    # 重み総和が W を超える最小の価値を返す
    # O(N*sum(ws))
    sumw = sum(ws)
    dp = [float('inf')] * (sumw+1)
    dp[0] = 0
    for i in range(N):
        v = vs[i]; w = ws[i]
        for j in range(sumw, w-1, -1):
            dp[j] = min(dp[j-w] + v, dp[j])
    return min(dp[W:])

Sortedset

SortedMultiset

転倒数を求める

ABC232-F
ABC244-D
ABC264-D
ABC332-D
ARC120-C

# 0-indexの数列で使える
def inversion(inds,N):
    bit = [0] * (N+1)
    def bit_add(x,w):
        while x <= N:
            bit[x] += w
            x += (x & -x)
    def bit_sum(x):
        ret = 0
        while x > 0:
            ret += bit[x]
            x -= (x & -x)
        return ret
    inv = 0
    for ind in reversed(inds):
        inv += bit_sum(ind + 1)
        bit_add(ind + 1, 1)
    return inv
    

ABC244D: Swap Hats

# 任意の配列で使えるバージョン
def inversion(in_arr):
    A,B = in_arr, sorted(in_arr)
    arr = []
    dic = {}
    for i,b in enumerate(B):
        if b in dic:
            arr[dic[b]].append(i)
        else:
            dic[b] = len(arr)
            arr.append([i])
    for i in range(len(arr)):
        arr[i].reverse()
    inds = []
    for a in A:
        inds.append(arr[dic[a]].pop())
    L = len(inds)
    bit = [0] * (L+1)
    def bit_add(x,w):
        while x <= L:
            bit[x] += w
            x += (x & -x)
    def bit_sum(x):
        ret = 0
        while x > 0:
            ret += bit[x]
            x -= (x & -x)
        return ret
    inv = 0
    for ind in reversed(inds):
        inv += bit_sum(ind + 1)
        bit_add(ind + 1, 1)
    return inv
    

ABC306F

素因数分解

nを素因数分解したリストを返す.
よく使います.計算量はO(√n)

参考記事:8を入れたら[[2,3]]が返ってくる.Countの手間が省けます.

def factorization(n):
    arr = []
    temp = n
    for i in range(2, int(-(-n**0.5//1))+1):
        if temp%i==0:
            cnt=0
            while temp%i==0:
                cnt+=1
                temp //= i
            arr.append([i, cnt])

    if temp!=1:
        arr.append([temp, 1])

    if arr==[]:
        arr.append([n, 1])

    return arr

1からnまでのxor

参考
https://atcoder.jp/contests/abc121/tasks/abc121_d

prime_list.py
def xor_sum(n):
    ans = 0
    keta = len(bin(n))-2
    n += 1
    for i in range(keta):
        block = 1<<(i+1)
        ans += (((n//block) * (block // 2) + max(0,n % block - block // 2)) &  1) << i
    return ans

拡張ユークリッド

二次元行列をListのままくるくるする

def reverse(H,W,L):
# 反転
  LL = [[0 for _ in range(H)] for _ in range(W)]
  for i in range(H):
    for j in range(W):
      LL[j][i] = L[i][j]
  return LL

def rotate(H,W,L):
# 時計回り90度回転
  LL = [[0 for _ in range(H)] for _ in range(W)]
  for i in range(H):
    for j in range(W):
      LL[j][-1-i] = L[i][j]
  return LL

参考文献

Markdown記法チートシート
Qiita記事作成方法 初心者の備忘録
Qiitaに投稿するときの心構え
競プロpython

1
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?