AtCoderジャッジアップデートで解禁されたcodonを実際に使ってみましょう。
本記事はテンプレ・ライブラリ編です。
変更点編は以下のリンクからご覧ください。
更新履歴
2025/11/08 ライブラリにstringと行列累乗を追加しました。
2025/11/19 ライブラリに最小費用流を追加しました。
2025/11/21 ライブラリに畳み込みを追加しました。
2025/11/25
- テンプレにint ハッシュ値変更を追加しました。
- ライブラリにfloor_sum・disjoint Sparse Table・Wavelet Matrixを追加しました。
変更点のまとめ 3行で
- codonは型厳格です。Pythonの型ヒントの構文で型を指定しましょう。
- intが符号つき64bit整数になりました。オーバーフローに注意してください。
- 機能は
@extendで追加・変更できます。足りない機能はDIYしましょう。
PyPy3のコードをcodonで動かす手順提案
先述の変更点に注意しながらコーディング・・・というのは少々大変なので、ここでは最小限の変更で、PyPy3のコードをcodonに移植する実験手順を提案します。
まずは動けばラッキーくらいの気持ちで試してみてください。
-
PyPy3のジャッジ結果が AC・TLE・MLE となるコードを用意します。
WA・REが出る場合は先に修正してください。 -
必ず 移植前にPyPy3のコードを保存してください。
codonが動かなかった場合のコード復元ができるようにしてください。 -
コードの冒頭に、後述のテンプレを貼ってください。
-
PyPy3のライブラリを使用する場合、インスタンス変数の型指定をしてください。
型はクラス変数と同様の位置に書きます。分からなければNoneを代入してみてください。
既存のcodon向けライブラリに差し替えてもよいでしょう。 -
コードテストにかけてみて、動けばラッキーです。動かなければ諦めてください。
テンプレ例
#mapの返り値をlist(map)で固定
import internal.static as _internal_static
def map(f, *args) -> list:
if _internal_static.len(args) == 0:
compile_error("map() expects at least one iterator")
elif _internal_static.len(args) == 1:
return [f(a) for a in args[0]]
else:
return [f(*a) for a in zip(*args)]
#int同士の除算結果をPythonの負の無限大丸めに合わせる
@extend
class int:
@pure
@llvm
def _floordiv_int_int(self: int, other: int) -> int:
%0 = sdiv i64 %self, %other
ret i64 %0
@overload
def __floordiv__(self, other: int):
d = self._floordiv_int_int(other)
m = self - d * other
if m and ((other ^ m) < 0):
d -= 1
return d
@pure
@llvm
def _mod_int_int(self: int, other: int) -> int:
%0 = srem i64 %self, %other
ret i64 %0
@overload
def __mod__(self, other: int) -> int:
m = self._mod_int_int(other)
if m and ((other ^ m) < 0):
m += other
return m
#Int[N](N <= 128)同士の除算結果をPythonの負の無限大丸めに合わせる
@extend
class Int:
def __floordiv__(self, other: Int[N]) -> Int[N]:
if N > 128:
compile_error("division is not supported on Int[N] when N > 128")
d = self._floordiv(other)
m = self - d * other
if m and ((other ^ m) < Int[N](0)):
d -= Int[N](1)
return d
def __mod__(self, other: Int[N]) -> Int[N]:
if N > 128:
compile_error("modulus is not supported on Int[N] when N > 128")
m = self._mod(other)
if m and ((other ^ m) < Int[N](0)):
m += other
return m
#int.bit_length, int.bit_countに対応
@extend
class int:
def bit_length(self):
return 64 - abs(self).__ctlz__()
def bit_count(self):
return abs(self).__ctpop__()
#floatの出力桁数を15桁に増やす
@extend
class float:
def __str__(self): return f'{self:15f}'
#巨大mod時のオーバーフローを回避 pow(base, -1, mod)に対応
def _extended_pow():
_builtin_pow = pow
def _codon_pow(base, exp):
return _builtin_pow(base, exp)
@overload
def _codon_pow(base: int, exp: int, mod: int) -> int:
'''
codon用に pow(base, exp, mod) を拡張した関数です。
1. (abs(mod) - 1) ** 2 >= 1 << 63 の場合に発生していたオーバーフローを回避しました。
2. pow(base: int, exp: 負整数, mod: int) による逆元計算に対応しました。
返り値の符号は mod の符号に一致します。
いずれかの引数に INF_MIN := -1 << 63 を渡した場合の動作は未定義です。
'''
if mod == 0:
raise ValueError('pow() 3rd argument cannot be 0')
if mod == 1 or mod == -1:
return 0
if exp < 0: #拡張ユークリッドの互除法
a, b, x, y = base, mod, 1, 0
while b:
q = a // b
a, b, x, y = b, a - q * b, y, x - q * y
if a != 1 and a != -1:
raise ValueError('base is not invertible for the given modulus')
b128, m128 = Int[128](x), Int[128](mod)
if a == -1:
b128 = - b128
exp = - exp
else:
b128, m128 = Int[128](base), Int[128](mod)
v128 = Int[128](1)
while exp:
if exp & 1 == 1:
v128 = v128 * b128 % m128
b128 = b128 * b128 % m128
exp >>= 1
v = int(v128)
return v + mod if v != 0 and (0 < v) != (0 < mod) else v
return _codon_pow
pow = _extended_pow()
もう少しだけ粘りたい場合、以下の点をチェックしてみてもいいですが、無理なときは無理なのでほどほどのところで切り上げてください。
-
sys.stdin.readlineやsys.setrecursionlimitといった、codon非対応のものは削除してください。 - match / case文や
:=(セイウチ演算子)はバグの原因なので差し替えてください。 - オーバーフローしていそうな場合、
intの代わりにi128を使うと改善するかもしれません。ただし、計算量保証はなくなります。 - setやdictにタプルを入れたい場合は、タプルの代わりにリストのまま入れてください。
- codonのsetやdictはハッシュの衝突が起きやすいです。TLEが取れない場合、補足の章にあるintハッシュ変更の拡張機能を追加してください。
補足: テンプレについて
codonは@extendや@overloadで多くの機能を変更できます。
筆者の改変例をいくつか提示します。使えそうなものだけ持って行ってください。
map入力受取
map(int, input().split())で入力を受け取れるようにします。
本来は手動でlist(map(int, input().split())) に差し替えるのがベストですが、少し面倒なので2通りの対応例を提示します。
どちらか片方を選んでご利用ください。
#generatorに__getitem__, __contains__を定義
@extend
class Generator:
def __getitem__(self: Generator[T], _: int) -> T:
if self.done(): raise StopIteration()
return self.next()
def __getitem__(self: Generator[T], key: slice) -> list[T]:
assert key.stop is None, (
'''本拡張では、末尾以外でのアンパッキングはできません。
例として、 N, *A = generator には対応しますが *A, N = generator は非対応です。
右辺の generator を list(generator) と書き換えてみてください。''')
return list(self)
def __contains__(self: Generator[T], key: T) -> bool:
return key in list(self)
#mapの返り値をlist(map)で固定
import internal.static as _internal_static
def map(f, *args) -> list:
if _internal_static.len(args) == 0:
compile_error("map() expects at least one iterator")
elif _internal_static.len(args) == 1:
return [f(a) for a in args[0]]
else:
return [f(*a) for a in zip(*args)]
整数 除算方向の変更
Pythonと同様の負の無限大丸めに変更します。
intとInt(符号つき任意倍長整数)の対応例を示します。
#int同士の除算結果をPythonの負の無限大丸めに合わせる
@extend
class int:
@pure
@llvm
def _floordiv_int_int(self: int, other: int) -> int:
%0 = sdiv i64 %self, %other
ret i64 %0
@overload
def __floordiv__(self, other: int):
d = self._floordiv_int_int(other)
m = self - d * other
if m and ((other ^ m) < 0):
d -= 1
return d
@pure
@llvm
def _mod_int_int(self: int, other: int) -> int:
%0 = srem i64 %self, %other
ret i64 %0
@overload
def __mod__(self, other: int) -> int:
m = self._mod_int_int(other)
if m and ((other ^ m) < 0):
m += other
return m
#Int[N](N <= 128)同士の除算結果をPythonの負の無限大丸めに合わせる
@extend
class Int:
def __floordiv__(self, other: Int[N]) -> Int[N]:
if N > 128:
compile_error("division is not supported on Int[N] when N > 128")
d = self._floordiv(other)
m = self - d * other
if m and ((other ^ m) < Int[N](0)):
d -= Int[N](1)
return d
def __mod__(self, other: Int[N]) -> Int[N]:
if N > 128:
compile_error("modulus is not supported on Int[N] when N > 128")
m = self._mod(other)
if m and ((other ^ m) < Int[N](0)):
m += other
return m
bit_count, bit_length
intクラスには__ctlz__, __cttz__, __ctpop__の命令が追加されているので、これを利用して実装します。
なお、PythonやPyPyのbit_count・bit_lengthと異なり非常に高速です。
#int.bit_length, int.bit_countに対応
@extend
class int:
def bit_length(self):
return 64 - abs(self).__ctlz__()
def bit_count(self):
return abs(self).__ctpop__()
int ハッシュ値変更
(2025/11/25 追記)
調査により、codonのset・dictはハッシュの衝突に脆弱だと判明しました。
原因の詳細は省きますが特に2冪の入力に弱く、例として$2^k$の倍数の入力を行うだけでハッシュが完全に衝突してしまいます。
なのでハッシュの衝突が目立つ場合はハッシュ値のアルゴリズムを変更して対応しましょう。以下はSplitMix64を用いた実装例です。
#int hash値をSplitMix64で変更する
#Reference: https://prng.di.unimi.it/splitmix64.c
@extend
class int:
def __hash__(self) -> int:
z: UInt[64] = UInt[64](self) + UInt[64](0x9e3779b97f4a7c15)
z = (z ^ (z >> UInt[64](30))) * UInt[64](0xbf58476d1ce4e5b9)
z = (z ^ (z >> UInt[64](27))) * UInt[64](0x94d049bb133111eb)
return int(z ^ (z >> UInt[64](31)))
float 出力桁数増加
print(float)での出力桁数はf stringで変更できます。
なお、float128の出力桁数変更は現在非対応です。
#floatの出力桁数を15桁に増やす
@extend
class float:
def __str__(self): return f'{self:15f}'
pow オーバーフロー回避・逆元計算追加
codonのpowはオーバーフローするうえ、pow(base, -1, mod)の逆元計算にも非対応でやや不便です。
早速改造しましょう。
この実装例では、2引数ならビルトインpowを呼び出し、3引数ならオリジナルpowを呼び出すように工夫しています。
また、繰り返し二乗法で$O(log exp)$回の128bit除算を行います。手元でテストする限りでは十分高速でしたが、この点ご留意ください。
#巨大mod時のオーバーフローを回避 pow(base, -1, mod)に対応
def _extended_pow():
_builtin_pow = pow
def _codon_pow(base, exp):
return _builtin_pow(base, exp)
@overload
def _codon_pow(base: int, exp: int, mod: int) -> int:
'''
codon用に pow(base, exp, mod) を拡張した関数です。
1. (abs(mod) - 1) ** 2 >= 1 << 63 の場合に発生していたオーバーフローを回避しました。
2. pow(base: int, exp: 負整数, mod: int) による逆元計算に対応しました。
返り値の符号は mod の符号に一致します。
いずれかの引数に INF_MIN := -1 << 63 を渡した場合の動作は未定義です。
'''
if mod == 0:
raise ValueError('pow() 3rd argument cannot be 0')
if mod == 1 or mod == -1:
return 0
if exp < 0: #拡張ユークリッドの互除法
a, b, x, y = base, mod, 1, 0
while b:
q = a // b
a, b, x, y = b, a - q * b, y, x - q * y
if a != 1 and a != -1:
raise ValueError('base is not invertible for the given modulus')
b128, m128 = Int[128](x), Int[128](mod)
if a == -1:
b128 = - b128
exp = - exp
else:
b128, m128 = Int[128](base), Int[128](mod)
v128 = Int[128](1)
while exp:
if exp & 1 == 1:
v128 = v128 * b128 % m128
b128 = b128 * b128 % m128
exp >>= 1
v = int(v128)
return v + mod if v != 0 and (0 < v) != (0 < mod) else v
return _codon_pow
pow = _extended_pow()
補足: ライブラリについて
PyPy3のライブラリに型を追加すれば大抵は動くようになりますが、一部移植が難しいものも存在します。
ここでは移植の参考として、筆者の実装例を提示します。
注意点
- PyPy3の自作ライブラリの移植です。説明のために関数名だけはACL風に寄せてみましたが、内部実装は全く異なります。
- ACLにない機能も含まれます。
- 遅いものも含まれます。特に、SortedSet・SortedListはものすごく遅いです。
- 最低限のランダムテストしか行っていません。バグはご容赦ください。
- これからACLを移植する方は、PythonからではなくC++から移植した方が速度が出ると思います。あとgithubとかを使った方がいいです
参考文献
UnionFind
for codon・PyPy3
Fenwick Tree
for codon・PyPy3
for codon
Segment Tree
for codon
Lazy Segment Tree
実装はアルゴリズム実技検定 公式テキスト[上級]~[エキスパート]編に影響を受けています。
合成関数の方向や木内二分探索の定義が間違っているかもしれません。
for codon
disjoint Sparse Table
disjoint Sparse Table for codon
#disjoint Sparse Table for codon
class disjointSparseTable[Te, Tf]:
'''
disjoint Sparse Table for codon
Θ(NlogN)の前計算の上で、O(1)で区間積を計算します。
A: 読み込ませる配列
identity_e: 単位元 要素の型はAと同じにしてください
node_f: 合成関数 f(node_Lt: Te, node_Rt: Te) -> node_new: Te
'''
N: int
_e: Te
_f: Tf
_node: list[Te]
__slots__ = ('N', '_e', '_f', '_node')
def __init__(self, A: list[Te], identity_e: Te, node_f: Tf) -> None:
self.N = N = len(A)
logN: int = max(1, len(bin(N - 1)) - 2) #(N - 1).bit_length()
self._e, self._f = identity_e, node_f
self._node = node = [self._e for _ in range(N * logN)]
for h in range(logN):
offset: int = h * N
for i, Ai in enumerate(A, start = offset):
node[i] = Ai
b = diff = 1 << h
step: int = 2 << h
while b < N:
node[b + offset] = back = A[b]
i: int = b + 1
Rt: int = min(b + diff, N)
while i < Rt:
node[i + offset] = back = self._f(back, A[i])
i += 1
b += step
b: int = diff - 1
while b < N:
node[b + offset] = back = A[b]
i: int = b - 1
Lt: int = b - diff
while Lt < i:
node[i + offset] = back = self._f(A[i], back)
i -= 1
b += step
def fold(self, Lt: int, Rt: int) -> Te:
'半開区間積A[Lt, Rt)を取得します。Lt == Rtの場合、単位元eを返します。'
assert 0 <= Lt <= Rt <= self.N
if Lt == Rt:
return self._e
Rt -= 1
if Lt == Rt:
return self._node[Lt]
h: int = 63 - (Lt ^ Rt).__ctlz__() #h ← (Lt ^ Rt).bit_length() - 1
return self._f( self._node[h * self.N + Lt], self._node[h * self.N + Rt] )
SCC
for codon・PyPy3
最大流
for codon
最小費用流
本実装はPyPy3の実装を流用しており、ダイクストラ法にセグメント木を用いています。
ですがcodonはheapqが十分に高速なので、ライブラリ整備の際はheapqでの実装をおすすめします。
for codon
二部グラフマッチング
for codon
suffix array, Z algorithm
内部実装は特に荒れています。
for codon, PyPy3
畳み込み
for codon, PyPy3
math
floor_sumは未完成で、特にオーバーフローの挙動が不安定です。
isqrt, inv_mod, gcd, ext_gcd, CRT for codon
#floor(√n)
def isqrt(n: int) -> int:
'floor(√n): m ** 2 <= n < (m + 1) ** 2 を満たす非負整数mを求めます。'
assert n >= 0
if n >= 3037000499 ** 2: #floor( √(2 ** 63 - 1) ) = 3037000499
return 3037000499
m: int = max(0, int(float(n).sqrt())) #int(n ** 0.5)
m2: int = m * m
while m2 < n:
m2 += m << 1 | 1
m += 1
while m2 > n:
m -= 1
m2 -= m << 1 | 1
return m
#pow(base, -1, mod)
def inv_mod(base: int, mod: int) -> int:
'pow(base, -1, mod) を求めます。返り値の符号はmodの符号と一致します。'
assert mod != 0, f'mod must not be zero. {mod = }'
if mod == 1 or mod == -1:
return 0
a, b, x, y = base, mod, 1, 0
while b:
q = a // b
a, b, x, y = b, a - q * b, y, x - q * y
if a != 1 and a != -1:
raise ValueError('base is not invertible for the given modulus')
if a == -1:
x = - x
return x + mod if (x ^ mod) < 0 else x
#最大公約数
gcd = lambda x, y: gcd(y, x % y) if y else abs(x)
#拡張ユークリッドの互除法
def ext_gcd(a: int, b: int) -> tuple[int, int, int]:
'''
g == a * x + b * y を満たす(g, x, y)を返します。
a == b == 0 の場合、(g, x, y) = (0, 1, 0) とします。
そうでない場合、(g, x, y)は以下の条件を満たします。
g = gcd(a, b) > 0
abs(x) <= max(1, abs(b // g))
abs(y) <= max(1, abs(a // g))
'''
if b == 0:
return (a, 1, 0) if a >= 0 else (- a, - 1, 0)
g, x, y = ext_gcd(b, a % b)
return g, y, x - (a // b) * y
#中国剰余定理
def CRT(R: list[int], M: list[int]) -> tuple[int, int]:
'''
n ≡ R[i] mod M[i] をすべて満たす非負整数n < lcm(M)を求め、(n, lcm(M))を返します。
答えがない場合は(0, 0)を、len(R) == len(M) == 0の場合は(0, 1)を返します。
制約: len(R) == len(M), 0 < M[i], lcm(M) < 2 ** 63
'''
assert len(R) == len(M)
assert all(0 < Mi for Mi in M)
R1, M1 = 0, 1
for R2, M2 in zip(R, M):
R2 %= M2
if R2 < 0:
R2 += M2
if M1 > M2:
R1, M1, R2, M2 = R2, M2, R1, M1
f, g, i, j = M1, M2, 1, 0 #g: gcd(M1, M2), i: invmod(M1 // g, M2 // g)
while f:
h = g // f
f, g, i, j = g - h * f, f, j, i - h * j
p, q = R1 - R2, M1 // g
r, s = p // g, p % g
if s:
return (0, 0)
R1, M1 = r * i % q * M2 + R2, M2 * q #assert abs(r * i) < M2 * q
if R1 < 0:
R1 += M1
return (R1, M1)
#floor sum for codon
#Reference: https://qiita.com/AkariLuminous/items/3e2c80baa6d5e6f3abe9
def floor_sum[T](n: T, m: T, a: T, b: T) -> T:
'''
sum( floor( (ai + b) / m ) for i in range(n) ) をO(log m)で求めます。
制約: 0 < m, 型は整数, ai + bがオーバーフローしない
'''
zero: T = n ^ n
one: T = - ~ zero
assert zero < m, f'mが正整数ではありません。{m = }'
if n <= zero:
return zero
ans: T = zero
while True:
if not zero <= a < m:
a_div, a = divmod(a, m)
ans += ( ((n - one) >> one) * n if n & one else (n - one) * (n >> one) ) * a_div
if not zero <= b < m:
b_div, b = divmod(b, m)
ans += n * b_div
y_max: T = a * n + b
if y_max < m:
return ans
else:
y_div, y_mod = divmod(y_max, m)
n, m, a, b = y_div, a, m, y_mod
素因数分解
高速素因数分解 for codon
#高速素因数分解 for codon
#Reference: https://qiita.com/t_fuki/items/7cd50de54d3c5d063b4a
class prime:
#内部関数
def _miller_rabin(N: int) -> bool:
if N < 2 or N & 1 == 0:
return N == 2
M, e = N - 1, (N - 1).__cttz__() #e = (M & - M).bit_length() - 1
d = M >> e #M = N - 1 = d << e
N128, M128 = UInt[128](N), UInt[128](M)
for a in ([2, 7, 61] if N < 48781 * 97561 else
[2, 325, 9375, 28178, 450775, 9780504, 1795265022]):
if a >= N:
continue
c = d
x128, y128 = UInt[128](1), UInt[128](a) #x = pow(a, d, N)
while c: #x = pow(a, d, N)
if c & 1:
x128 = x128 * y128 % N128
y128 = y128 * y128 % N128
c >>= 1
if x128 == UInt[128](1): #x = pow(a, d, N) ≡ 1 ならおそらく素数
continue
while x128 != M128: #pow(x, 2 ** (c := e未満), N) ≡ -1 ならおそらく素数
x128 = x128 * x128 % N128
c += 1
if x128 == UInt[128](1) or c == e:
return False
return True
def _pollard_rho(N: int) -> int: #Nの素因数を探索 ミラーラビンを参照する
assert N > 0
if N & 1 == 0:
return 2
if N == 1 or prime._miller_rabin(N):
return N
while True:
N128 = Int[128](N)
step = int(N ** 0.125) + 1
for c in range(1, N):
#f(n) = n ** 2 + c mod N と疑似乱数を定義する
#y128 = f^{s}(0), z128: Π(x128 - y128) mod N128
#g: gcd(x, y) t: sの次の目標となる2冪
y128, z128, c128 = Int[128](0), Int[128](1), Int[128](c)
g, s, t = 1, 0, 1
while g == 1:
x128 = y128
nxt_s = (3 * t) >> 2
for _ in range(nxt_s - s):
y128 = (y128 * y128 + c128) % N128 #y ← f(y)
s = nxt_s
while s < t and g == 1:
backtrack128 = y128
for _ in range(min(step, t - s)): #N ** 1/8回まとめてgcdを計算
y128 = (y128 * y128 + c128) % N128 #y ← f(y)
z128 = z128 * (x128 - y128) % N128
g, h = N, abs(int(z128))
while h: #g ← gcd(N, z128)
g, h = h, g % h
s += step
s, t = t, t << 1
if g == N:
g, y128 = 1, backtrack128
while g == 1:
y128 = (y128 * y128 + c128) % N128 #y ← f(y)
g, h = abs(int(x128 - y128)), N
while h: #g ← gcd(N, x128 - y128)
g, h = h, g % h
if g == N:
continue #検出失敗
if prime._miller_rabin(g):
return g
elif prime._miller_rabin(N // g):
return N // g
else:
N = g
break #while Trueへ
def _fast_fact(N: int) -> list[tuple[int, int]]:
assert N >= 1
ans: list[tuple[int, int]] = []
if N & 1 == 0:
ans.append((2, N.__cttz__()))
N >>= N.__cttz__()
p2 = 1
for p in range(3, int(N ** 0.25), 2): #O(N ** 1/4)回のためし割り
p2 += (p - 1) << 2 #assert p * p == p2
if p2 > N:
if N > 1:
ans.append((N, 1))
N = 1
break
if N % p == 0:
e = 0
while N % p == 0:
N //= p
e += 1
ans.append((p, e))
while N > 1:
p = prime._pollard_rho(N)
e = 0
while N % p == 0:
N //= p
e += 1
ans.append((p, e))
ans.sort()
return ans
def _enumerate_divisor(N: int) -> list[int]:
F: list[tuple[int, int]] = prime._fast_fact(N)
Rt: int = 1
for _, e in F:
Rt *= e + 1
D: list[int] = [1] * Rt
Rt: int = 1
for p, e in F:
for Lt in range(Rt * e):
D[Rt] = D[Lt] * p
Rt += 1
D.sort()
return D
#素数判定
def is_prime(N: int) -> bool:
'''
ミラーラビン素数判定法により素数判定を行います。
計算量: int128の剰余演算の計算量をO(L)としたとき、O(7L * logN)
制約: 1 <= N < 2 ** 63
'''
assert 1 <= N
return prime._miller_rabin(N)
#O(N ** 1/4) 高速素因数分解
def factorize(N: int) -> list[tuple[int, int]]:
'''
Nを素因数分解し、(素因数, 次数) の形のリストとして返します。
期待計算量: int128の剰余演算をO(L)としたとき、O(L * N ** 1/4)
制約: 1 <= N < 2 ** 63
'''
assert 1 <= N
return prime._fast_fact(N)
#約数列挙
def divisor(N: int) -> list[int]:
'''
Nの約数を列挙し、ソートして返します。
期待計算量: 約数の個数をdとしたとき、prime.factorize + O(d * logd)
制約: 1 <= N < 2 ** 63
'''
assert 1 <= N
return prime._enumerate_divisor(N)
SortedSet・SortedList
tatyamさんの実装のほうが優秀なので、そちらを利用してください。
SkipListを用いた期待$O(logN)$の実装をしてみたのですが、遅すぎてだめでした
内部実装もぐちゃくちゃなので手の施しようがありません。供養として置いておきます
SortedSet for codon
SortedList for codon
Wavelet Matrix
Wavelet Matrix for codon・PyPy3
#Wavelet Matrix for codon, PyPy3
import heapq as _WM_heapq
class WaveletMatrix:
'''
Wavelet Matrix for codon, PyPy3
非負整数列Aに対する検索を対数時間で行います。
bit_countの実装上、すべての操作に O(log wordsize) = O(3) の項がかかります。
N = len(A), M = max(A) として
構築: 時間 O(N logM), 空間O(N)
検索: 時間 O(logM) ~
A: 読み込ませたい非負整数列
'''
_N: int
_logM: int
_size: int
_C: list[int]
_D: list[int]
_zero: list[int]
_one: list[int]
_stack: list[int]
__slots__ = ('_N', '_logM', '_size', '_C', '_D', '_zero', '_one', '_stack')
def __init__(self, A: list[int]) -> None:
assert len(A) == 0 or min(A) >= 0, f'Aに負の要素が含まれます。{min(A) = }'
assert len(A) < 2 ** 29, f'len(A)が長すぎます。 {len(A) = }'
self._N = N = len(A)
maxA: int = 0 if len(A) == 0 else max(A)
self._logM = logM = 0 if maxA == 0 else len(bin(maxA)) - 2
self._size = size = -(- N >> 5)
self._C = C = [0] * size * logM #FIDをlogM個作成
self._zero: list[int] = [0] * logM
self._one: list[int] = [0] * logM
self._stack: list[int] = [0] * ((logM + 1) << 1)
D: list[int] = list(range(N))
E: list[int] = [0] * N #DとEをswapしながら上の桁から決定
for k in range(logM - 1, -1, -1):
offset: int = size * k
zero = one = now = 0
for b in range(offset, offset + size):
C[b] = one << 32
for c in range(32):
if now >= N:
break
v: int = A[D[now]] >> k & 1
if v == 0:
zero += 1
else:
one += 1
C[b] |= 1 << c
now += 1
Lt, Rt = 0, zero
for D_now in D:
if A[D_now] >> k & 1 == 0:
E[Lt] = D_now
Lt += 1
else:
E[Rt] = D_now
Rt += 1
self._zero[k], self._one[k] = zero, one
D, E = E, D
self._D: list[int] = D
#内部関数: FID
def _FID_access(self, k: int, i: int) -> int: #FID[k]に対し、B[i] >> k & 1
return self._C[self._size * k + (i >> 5)] >> (i & 31) & 1
def _FID_rank(self, k: int, i: int, num: int) -> int: #[0, i]のnumの個数
if i < 0:
return 0
Ci: int = self._C[self._size * k + (i >> 5)]
n: int = Ci & ~(-1 << ((i & 31) + 1)) #one = (Ci >> 32) + n.bit_count()
n: int = ( n & 0x55555555 ) + ( (n >> 1) & 0x55555555 )
n: int = ( n & 0x33333333 ) + ( (n >> 2) & 0x33333333 )
n: int = ( n & 0x0F0F0F0F ) + ( (n >> 4) & 0x0F0F0F0F )
one: int = (Ci >> 32) + (n * 0x1010101 >> 24 & 63)
return one if num == 1 else 1 + i - one
def _FID_stable_sort(self, k: int, i: int) -> int: #安定ソート後のiの位置
num: int = self._C[self._size * k + (i >> 5)] >> (i & 31) & 1 #access(i)
if num == 0:
return self._FID_rank(k, i - 1, 0)
else:
return self._FID_rank(k, i - 1, 1) + self._zero[k]
def _FID_range_sort(self, k: int, Lt: int, Rt: int, num: int) -> tuple[int, int]:
offset: int = 0 if num == 0 else self._zero[k]
return (offset + self._FID_rank(k, Lt - 1, num),
offset + self._FID_rank(k, Rt - 1, num))
#基本機能: 計算量がO(logM)
def access(self, i: int) -> int:
'A[i]をO(logM)で取得します。'
if i < 0:
i += self._N
assert 0 <= i < self._N
ans: int = 0
for k in range(self._logM - 1, -1, -1):
b: int = self._FID_access(k, i)
ans |= b << k
i: int = self._FID_stable_sort(k, i)
return ans
def rank(self, Lt: int, Rt: int, value: int) -> int:
'A[Lt, Rt)のvalueの出現回数をO(logM)で取得します。'
assert 0 <= Lt <= Rt <= self._N
if value < 0 or value >> self._logM >= 1:
return 0
for k in range(self._logM - 1, -1, -1):
Lt, Rt = self._FID_range_sort(k, Lt, Rt, value >> k & 1)
return Rt - Lt
def select(self, cnt: int, value: int) -> int:
'''
0-indexedでcnt個目のvalueの添字をO(logM)で取得します。
特に、cnt = 0 かつ value in A の時は A[Lt: Rt].index(value) と返り値が一致します。
cnt >= A.count(value) の場合、Nを返します。
'''
if value < 0 or value >> self._logM >= 1:
return self._N
Lt, Rt = 0, self._N
for k in range(self._logM - 1, -1, -1):
Lt, Rt = self._FID_range_sort(k, Lt, Rt, value >> k & 1)
if cnt >= Rt - Lt:
return self._N
else:
return self._D[Lt + cnt]
def kth_min(self, Lt: int, Rt: int, k: int) -> int:
'sorted( A[Lt, Rt) )[k] : A[Lt, Rt)の小さい側からk番目の要素 をO(logM)で取得します。'
assert 0 <= Lt <= Rt <= self._N
assert 0 <= k < Rt - Lt, f'k is out of range: {Rt - Lt = }, {k = }'
cnt: int = k #内部的に添字をk → cntに変更
ans: int = 0
for k in range(self._logM - 1, -1, -1):
Li: int = self._FID_rank(k, Lt - 1, 0)
Ri: int = self._FID_rank(k, Rt - 1, 0)
zero: int = Ri - Li
if cnt < zero:
Lt, Rt = Li, Ri
else: #Lt, Rt = self._FID_range_sort(k, Lt, Rt, 1)
ans |= 1 << k
cnt -= zero
offset: int = self._zero[k]
Lt, Rt = offset + (Lt - Li), offset + (Rt - Ri)
return ans
def range_freq(self, Lt: int, Rt: int, vL: int, vR: int) -> int:
'A[Lt, Rt)に存在する、 vL <= Ai < vR を満たすAiの個数をO(logM)で取得します。'
assert 0 <= Lt <= Rt <= self._N
if not vL < vR:
return 0
ans: int = Rt - Lt
stack: list[int] = self._stack
if vL > 0:
stack[0], stack[1] = 0, self._logM << 58 | Lt << 29 | Rt
d: int = 2
while d:
c, x = stack[d - 2], stack[d - 1]
d -= 2
k, Li, Ri = x >> 58, x >> 29 & 0x1FFFFFFF, x & 0x1FFFFFFF
if c + (1 << k) <= vL:
ans -= Ri - Li
continue
k -= 1
if k == -1:
break
for b in (1, 0):
Lj, Rj = self._FID_range_sort(k, Li, Ri, b)
if Lj != Rj:
stack[d], stack[d + 1] = c | b << k, k << 58 | Lj << 29 | Rj
d += 2
if vR <= ~(-1 << self._logM):
stack[0], stack[1] = 0, self._logM << 58 | Lt << 29 | Rt
d: int = 2
while d:
c, x = stack[d - 2], stack[d - 1]
d -= 2
k, Li, Ri = x >> 58, x >> 29 & 0x1FFFFFFF, x & 0x1FFFFFFF
if c >= vR: #変更点
ans -= Ri - Li
continue
k -= 1
if k == -1:
break
for b in (0, 1): #変更点
Lj, Rj = self._FID_range_sort(k, Li, Ri, b)
if Lj != Rj:
stack[d], stack[d + 1] = c | b << k, k << 58 | Lj << 29 | Rj
d += 2
return ans
def prev_value(self, Lt: int, Rt: int, value: int) -> int:
'''
A[Lt, Rt)のうち、valueより真に小さい最大値をO(logM)で取得します。
そのような値が存在しない場合、-1を返します。
'''
assert 0 <= Lt <= Rt <= self._N
if Lt == Rt or value <= 0:
return -1
stack: list[int] = self._stack
stack[0], stack[1] = 0, self._logM << 58 | Lt << 29 | Rt
d: int = 2
while d:
c, x = stack[d - 2], stack[d - 1]
d -= 2
k, Li, Ri = x >> 58, x >> 29 & 0x1FFFFFFF, x & 0x1FFFFFFF
if c >= value: #変更点
continue
k -= 1
if k == -1:
return c
for b in (0, 1): #変更点
Lj, Rj = self._FID_range_sort(k, Li, Ri, b)
if Lj != Rj:
stack[d], stack[d + 1] = c | b << k, k << 58 | Lj << 29 | Rj
d += 2
else:
return -1
def next_value(self, Lt: int, Rt: int, value: int) -> int:
'''
A[Lt, Rt)のうち、valueより真に大きい最小値をO(logM)で取得します。
そのような値が存在しない場合、-1を返します。
'''
assert 0 <= Lt <= Rt <= self._N
if Lt == Rt or value >= ~(-1 << self._logM):
return -1
value += 1
stack: list[int] = self._stack
stack[0], stack[1] = 0, self._logM << 58 | Lt << 29 | Rt
d: int = 2
while d:
c, x = stack[d - 2], stack[d - 1]
d -= 2
k, Li, Ri = x >> 58, x >> 29 & 0x1FFFFFFF, x & 0x1FFFFFFF
if c + (1 << k) <= value:
continue
k -= 1
if k == -1:
return c
for b in (1, 0):
Lj, Rj = self._FID_range_sort(k, Li, Ri, b)
if Lj != Rj:
stack[d], stack[d + 1] = c | b << k, k << 58 | Lj << 29 | Rj
d += 2
else:
return -1
#基本機能: 計算量がO(logM)でないもの
def topk_mode(self, Lt: int, Rt: int, k: int) -> list[tuple[int, int]]:
'''
A[Lt, Rt)の頻度を数え、E: [(値, 個数) のリスト]を作成します。
その後 個数の降順・同率なら値の昇順 にEをソートし、E[:cnt]を返します。
計算量は O( cnt * logNlogM ) です。返り値のタプルの順序は(値, 個数)です。
'''
assert 0 <= Lt <= Rt <= self._N
ans: list[tuple[int, int]] = []
cnt: int = k #内部的に添字をk → cntに変更
if Lt == Rt or cnt <= 0:
return ans
Q: list[tuple[int, int, int]] = [
( ~( (Rt - Lt) << 32 | self._logM), 0, Lt << 32 | Rt )]
while Q and len(ans) < cnt:
x, y, z = _WM_heapq.heappop(Q)
w, k = (~ x) >> 32, ((~ x) & 0xFFFFFFFF) - 1
Li, Ri = z >> 32, z & 0xFFFFFFFF
assert Li <= Ri and 0 <= w == Ri - Li and k >= -1
if k == -1:
ans.append((y, Ri - Li))
continue
for b in range(2):
Lj, Rj = self._FID_range_sort(k, Li, Ri, b)
if Rj > Lj:
_WM_heapq.heappush(
Q, (~ ((Rj - Lj) << 32 | k), y | b << k, Lj << 32 | Rj ))
return ans
def intersect(self, L1: int, R1: int, L2: int, R2: int) -> list[tuple[int, int]]:
'''
A[L1, R1) と A[L2, R2) の共通要素を取り出し、(値, 個数) の昇順で返します。
計算量は O( (R - L)logM )です。
'''
assert 0 <= L1 <= R1 <= self._N and 0 <= L2 <= R2 <= self._N
ans: list[tuple[int, int]] = []
if L1 == R1 or L2 == R2:
return ans
stack: list[int] = self._stack
while len(stack) < (self._logM + 1) * 3:
stack.append(0)
stack[0], stack[1], stack[2] = 0, self._logM << 58 | L1 << 29 | R1, L2 << 29 | R2
d: int = 3
while d:
c, y, z = stack[d - 3], stack[d - 2], stack[d - 1]
d -= 3
k, L1i, R1i = (y >> 58) - 1, y >> 29 & 0x1FFFFFFF, y & 0x1FFFFFFF
L2i, R2i = z >> 29 & 0x1FFFFFFF, z & 0x1FFFFFFF
if k == -1:
ans.append((c, min(R1i - L1i, R2i - L2i)))
continue
for b in (1, 0):
L1j, R1j = self._FID_range_sort(k, L1i, R1i, b)
L2j, R2j = self._FID_range_sort(k, L2i, R2i, b)
if L1j != R1j and L2j != R2j:
stack[d] = c | b << k
stack[d + 1] = k << 58 | L1j << 29 | R1j
stack[d + 2] = L2j << 29 | R2j
d += 3
return ans
行列累乗
行列累乗 for codon, PyPy3
#行列累乗 for codon, PyPy3
class matrix_power:
'''
行列累乗 for codon, PyPy3
法MODの下、行列計算を行います。
行列は2次元リストとして渡してください。
返り値は新しい2次元リストで、全成分が0以上MOD未満を満たします。
MOD: 法
'''
MOD: int
_acc_limit: int
__slots__ = ('MOD', '_acc_limit')
def __init__(self, MOD: int) -> None:
self.MOD = MOD
if MOD > 3037000500: #3_037_ + 3_037 ** 2 >= 2 ** 63
self._acc_limit: int = 1
else: #(MOD - 1) + _acc_limit * (MOD - 1) ** 2 < 2 ** 63
self._acc_limit: int = (~(-1 << 63) - (MOD - 1)) // ((MOD - 1) ** 2)
#内部関数
def _matrix_add(self, A: list[list[int]], B: list[list[int]],
H: int, W: int) -> list[list[int]]:
C: list[list[int]] = [[0] * W for _ in range(H)]
for h in range(H):
Ah, Bh, Ch = A[h], B[h], C[h]
for w, Ahw in enumerate(Ah):
Ch[w] = (Ahw % self.MOD) + (Bh[w] % self.MOD)
if Ch[w] >= self.MOD:
Ch[w] -= self.MOD
while Ch[w] < 0:
Ch[w] += self.MOD
return C
def _matrix_mul(self, A: list[list[int]], B: list[list[int]], C: list[list[int]],
H: int, W: int, X: int) -> list[list[int]]:
#すべての成分が 0 <= A[h][x], B[x][w] < MOD を満たすことを要求する
B_w: list[int] = [0] * X
for w in range(W): #C[h][w] = sum(A[h][x] * B[x][w] for all x)
for x in range(X):
B_w[x] = B[x][w]
for h in range(H):
cnt: int = 0
k: int = self._acc_limit
for x, Ahx in enumerate(A[h]):
cnt += Ahx * B_w[x]
k -= 1
if k == 0:
cnt %= self.MOD
k = self._acc_limit
C[h][w] = cnt % self.MOD
return C
def _matrix_doubling_mul(self, A: list[list[int]], N: int, k: int) -> list[list[int]]:
if k == 0:
return self.eye(N)
C: list[list[int]] = [[0] * N for _ in range(N)]
for h in range(N):
Ch: list[int] = C[h]
for w, Ahw in enumerate(A[h]):
Ch[w] = Ahw % self.MOD
if Ch[w] < 0:
Ch[w] += self.MOD
if k == 1:
return C
D: list[list[int]] = [[0] * N for _ in range(N)]
E: list[list[int]] = [C[h][:] for h in range(N)]
for i in range(len(bin(k)) - 4, -1, -1):
self._matrix_mul(C, C, D, N, N, N)
C, D = D, C
if k >> i & 1 == 1:
self._matrix_mul(C, E, D, N, N, N)
C, D = D, C
return C
#基本機能
def eye(self, N: int) -> list[list[int]]:
'N行N列の単位行列を返します。'
A: list[list[int]] = [[0] * N for _ in range(N)]
for i in range(N):
A[i][i] = 1
return A
def add(self, A: list[list[int]], B: list[list[int]]) -> list[list[int]]:
'行列C := A + B を新しく生成します。'
assert len(A) == len(B)
if len(A) == 0:
return []
H: int = len(A)
W: int = len(A[0])
assert all(len(Ai) == W for Ai in A) and all(len(Bi) == W for Bi in B)
return self._matrix_add(A, B, H, W)
def mul(self, A: list[list[int]], B: list[list[int]]) -> list[list[int]]:
'H行X列の行列Aと、X行W列の行列Bから、行列C := A * Bを新しく作成します。'
H: int = len(A)
if H == 0:
return []
X: int = len(A[0])
assert all(len(Ai) == X for Ai in A) and len(B) == X
if X == 0:
return [[] for _ in range(H)]
W: int = len(B[0])
assert all(len(Bi) == W for Bi in B)
new_A: list[list[int]] = [[0] * X for _ in range(H)]
new_B: list[list[int]] = [[0] * W for _ in range(X)]
for h in range(H):
new_Ah: list[int] = new_A[h]
for x, Ahx in enumerate(A[h]):
new_Ah[x] = Ahx % self.MOD
if new_Ah[x] < 0:
new_Ah[x] += self.MOD
for x in range(X):
new_Bx: list[int] = new_B[x]
for w, Bxw in enumerate(B[x]):
new_Bx[w] = Bxw % self.MOD
if new_Bx[w] < 0:
new_Bx[w] += self.MOD
return self._matrix_mul(new_A, new_B, [[0] * W for _ in range(H)], H, W, X)
def doubling_mul(self, A: list[list[int]], k: int) -> list[list[int]]:
'正方行列Aから、正方行列C := A ** k を新しく作成します。'
N: int = len(A)
if N == 0:
return []
assert all(len(Ai) == N for Ai in A)
assert k >= 0
return self._matrix_doubling_mul(A, N, k)
おわりに
おわりです。
codonの開拓が進んだらうれしいです。