はじめに
競技プログラミングでよく使うコードをテンプレート化した。
コピペで使えるように動作確認済み。
入力の高速化
import sys
input = sys.stdin.readline # これだけで高速化
# 1つの整数
N = int(input())
# 複数の整数
N, M = map(int, input().split())
# リスト
A = list(map(int, input().split()))
# 文字列(改行を除去)
S = input().strip()
数学系
最大公約数・最小公倍数
def gcd(a, b):
while b:
a, b = b, a % b
return a
def lcm(a, b):
return a * b // gcd(a, b)
# Python 3.9+ なら math.gcd, math.lcm が使える
素数判定
def is_prime(n):
if n < 2:
return False
if n == 2:
return True
if n % 2 == 0:
return False
for i in range(3, int(n**0.5) + 1, 2):
if n % i == 0:
return False
return True
エラトステネスの篩
def sieve(n):
"""n以下の素数をリストで返す"""
is_prime = [True] * (n + 1)
is_prime[0] = is_prime[1] = False
for i in range(2, int(n**0.5) + 1):
if is_prime[i]:
for j in range(i*i, n + 1, i):
is_prime[j] = False
return [i for i in range(n + 1) if is_prime[i]]
sieve(100)
# [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97]
素因数分解
def factorize(n):
factors = []
d = 2
while d * d <= n:
while n % d == 0:
factors.append(d)
n //= d
d += 1
if n > 1:
factors.append(n)
return factors
factorize(60) # [2, 2, 3, 5]
約数列挙
def divisors(n):
divs = []
for i in range(1, int(n**0.5) + 1):
if n % i == 0:
divs.append(i)
if i != n // i:
divs.append(n // i)
return sorted(divs)
divisors(60) # [1, 2, 3, 4, 5, 6, 10, 12, 15, 20, 30, 60]
modでの計算
MOD = 998244353
def mod_pow(base, exp, mod=MOD):
"""繰り返し二乗法"""
result = 1
base %= mod
while exp > 0:
if exp & 1:
result = result * base % mod
exp >>= 1
base = base * base % mod
return result
def mod_inv(a, mod=MOD):
"""modの逆元(フェルマーの小定理)"""
return mod_pow(a, mod - 2, mod)
階乗・二項係数
class Factorial:
def __init__(self, n, mod=MOD):
self.mod = mod
self.fact = [1] * (n + 1)
self.inv_fact = [1] * (n + 1)
for i in range(1, n + 1):
self.fact[i] = self.fact[i-1] * i % mod
self.inv_fact[n] = mod_pow(self.fact[n], mod - 2, mod)
for i in range(n - 1, -1, -1):
self.inv_fact[i] = self.inv_fact[i+1] * (i + 1) % mod
def comb(self, n, r):
if r < 0 or r > n:
return 0
return self.fact[n] * self.inv_fact[r] % self.mod * self.inv_fact[n-r] % self.mod
f = Factorial(100)
f.comb(10, 3) # 120
f.comb(100, 50) # 198626801 (mod 998244353)
Union-Find
class UnionFind:
def __init__(self, n):
self.parent = list(range(n))
self.rank = [0] * n
self.size = [1] * n
def find(self, x):
if self.parent[x] != x:
self.parent[x] = self.find(self.parent[x]) # 経路圧縮
return self.parent[x]
def union(self, x, y):
px, py = self.find(x), self.find(y)
if px == py:
return False
if self.rank[px] < self.rank[py]:
px, py = py, px
self.parent[py] = px
self.size[px] += self.size[py]
if self.rank[px] == self.rank[py]:
self.rank[px] += 1
return True
def same(self, x, y):
return self.find(x) == self.find(y)
def get_size(self, x):
return self.size[self.find(x)]
# 使用例
uf = UnionFind(5)
uf.union(0, 1)
uf.union(2, 3)
uf.union(1, 3)
uf.same(0, 2) # True
uf.get_size(0) # 4
グラフ探索
BFS
from collections import deque
def bfs(graph, start):
"""幅優先探索で最短距離を求める"""
n = len(graph)
dist = [-1] * n
dist[start] = 0
queue = deque([start])
while queue:
v = queue.popleft()
for u in graph[v]:
if dist[u] == -1:
dist[u] = dist[v] + 1
queue.append(u)
return dist
# graph[v] = [隣接頂点のリスト]
graph = [[1, 2], [0, 3], [0, 3], [1, 2, 4], [3]]
bfs(graph, 0) # [0, 1, 1, 2, 3]
ダイクストラ法
from heapq import heappush, heappop
def dijkstra(graph, start):
"""graph[v] = [(u, cost), ...]"""
n = len(graph)
dist = [float('inf')] * n
dist[start] = 0
heap = [(0, start)]
while heap:
d, v = heappop(heap)
if d > dist[v]:
continue
for u, cost in graph[v]:
if dist[v] + cost < dist[u]:
dist[u] = dist[v] + cost
heappush(heap, (dist[u], u))
return dist
# 使用例
weighted_graph = [
[(1, 2), (2, 5)], # 0 → 1(cost=2), 0 → 2(cost=5)
[(0, 2), (2, 1), (3, 3)],
[(0, 5), (1, 1), (3, 1)],
[(1, 3), (2, 1)]
]
dijkstra(weighted_graph, 0) # [0, 2, 3, 4]
二分探索
from bisect import bisect_left, bisect_right
arr = [1, 2, 4, 4, 4, 7, 9]
bisect_left(arr, 4) # 2 (4が入る最左位置)
bisect_right(arr, 4) # 5 (4が入る最右位置)
# 条件を満たす最小値を見つける
def binary_search(ok, ng, condition):
while abs(ok - ng) > 1:
mid = (ok + ng) // 2
if condition(mid):
ok = mid
else:
ng = mid
return ok
累積和
from itertools import accumulate
arr = [1, 2, 3, 4, 5]
prefix = list(accumulate(arr, initial=0))
# [0, 1, 3, 6, 10, 15]
# 区間[l, r)の和
def range_sum(l, r):
return prefix[r] - prefix[l]
メモ化再帰
from functools import lru_cache
@lru_cache(maxsize=None)
def fib(n):
if n < 2:
return n
return fib(n-1) + fib(n-2)
fib(100) # 瞬時に計算
よく使うimport
import sys
from collections import defaultdict, deque, Counter
from heapq import heappush, heappop, heapify
from bisect import bisect_left, bisect_right
from itertools import permutations, combinations, accumulate
from functools import lru_cache
import math
input = sys.stdin.readline
sys.setrecursionlimit(10**6) # 再帰上限を上げる
計算量の目安
| N | 許容される計算量 |
|---|---|
| $10^6$ | O(N) |
| $10^5$ | O(N log N) |
| $10^4$ | O(N^2) |
| $500$ | O(N^3) |
| $20$ | O(2^N) |
| $10$ | O(N!) |
まとめ
| カテゴリ | 関数/クラス |
|---|---|
| 数学 |
is_prime, sieve, factorize, divisors
|
| mod計算 |
mod_pow, mod_inv, Factorial
|
| データ構造 | UnionFind |
| グラフ |
bfs, dijkstra
|
| 探索 |
bisect_left, bisect_right
|
| その他 |
accumulate, lru_cache
|
このテンプレートをベースに、カスタマイズしてください!