4
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【競プロ】Python codonの使い方(仮) テンプレ・ライブラリ編

Last updated at Posted at 2025-11-04

AtCoderジャッジアップデートで解禁されたcodonを実際に使ってみましょう。

本記事はテンプレ・ライブラリ編です。
変更点編は以下のリンクからご覧ください。

更新履歴

2025/11/08 ライブラリにstringと行列累乗を追加しました。
2025/11/19 ライブラリに最小費用流を追加しました。
2025/11/21 ライブラリに畳み込みを追加しました。
2025/11/25

  • テンプレにint ハッシュ値変更を追加しました。
  • ライブラリにfloor_sum・disjoint Sparse Table・Wavelet Matrixを追加しました。

変更点のまとめ 3行で

  • codonは型厳格です。Pythonの型ヒントの構文で型を指定しましょう。
  • intが符号つき64bit整数になりました。オーバーフローに注意してください。
  • 機能は@extendで追加・変更できます。足りない機能はDIYしましょう。

PyPy3のコードをcodonで動かす手順提案

先述の変更点に注意しながらコーディング・・・というのは少々大変なので、ここでは最小限の変更で、PyPy3のコードをcodonに移植する実験手順を提案します。
まずは動けばラッキーくらいの気持ちで試してみてください。

  1. PyPy3のジャッジ結果が AC・TLE・MLE となるコードを用意します。
    WA・REが出る場合は先に修正してください。

  2. 必ず 移植前にPyPy3のコードを保存してください。
    codonが動かなかった場合のコード復元ができるようにしてください。

  3. コードの冒頭に、後述のテンプレを貼ってください。

  4. PyPy3のライブラリを使用する場合、インスタンス変数の型指定をしてください。
    型はクラス変数と同様の位置に書きます。分からなければNoneを代入してみてください。
    既存のcodon向けライブラリに差し替えてもよいでしょう。

  5. コードテストにかけてみて、動けばラッキーです。動かなければ諦めてください。

テンプレ例
テンプレ例
#mapの返り値をlist(map)で固定
import internal.static as _internal_static
def map(f, *args) -> list:
    if _internal_static.len(args) == 0:
        compile_error("map() expects at least one iterator")
    elif _internal_static.len(args) == 1:
        return [f(a) for a in args[0]]
    else:
        return [f(*a) for a in zip(*args)]

#int同士の除算結果をPythonの負の無限大丸めに合わせる
@extend
class int:
    @pure
    @llvm
    def _floordiv_int_int(self: int, other: int) -> int:
        %0 = sdiv i64 %self, %other
        ret i64 %0
    @overload
    def __floordiv__(self, other: int):
        d = self._floordiv_int_int(other)
        m = self - d * other
        if m and ((other ^ m) < 0):
            d -= 1
        return d
    @pure
    @llvm
    def _mod_int_int(self: int, other: int) -> int:
        %0 = srem i64 %self, %other
        ret i64 %0
    @overload
    def __mod__(self, other: int) -> int:
        m = self._mod_int_int(other)
        if m and ((other ^ m) < 0):
            m += other
        return m

#Int[N](N <= 128)同士の除算結果をPythonの負の無限大丸めに合わせる
@extend
class Int:
    def __floordiv__(self, other: Int[N]) -> Int[N]:
        if N > 128:
            compile_error("division is not supported on Int[N] when N > 128")
        d = self._floordiv(other)
        m = self - d * other
        if m and ((other ^ m) < Int[N](0)):
            d -= Int[N](1)
        return d
    def __mod__(self, other: Int[N]) -> Int[N]:
        if N > 128:
            compile_error("modulus is not supported on Int[N] when N > 128")
        m = self._mod(other)
        if m and ((other ^ m) < Int[N](0)):
            m += other
        return m

#int.bit_length, int.bit_countに対応
@extend
class int:
    def bit_length(self): 
        return 64 - abs(self).__ctlz__()
    def bit_count(self):
        return abs(self).__ctpop__()

#floatの出力桁数を15桁に増やす
@extend
class float:
    def __str__(self): return f'{self:15f}'

#巨大mod時のオーバーフローを回避  pow(base, -1, mod)に対応
def _extended_pow():
    _builtin_pow = pow
    def _codon_pow(base, exp):
        return _builtin_pow(base, exp)
    @overload
    def _codon_pow(base: int, exp: int, mod: int) -> int:
        '''
        codon用に pow(base, exp, mod) を拡張した関数です。
        1. (abs(mod) - 1) ** 2 >= 1 << 63 の場合に発生していたオーバーフローを回避しました。
        2. pow(base: int, exp: 負整数, mod: int) による逆元計算に対応しました。

        返り値の符号は mod の符号に一致します。
        いずれかの引数に INF_MIN := -1 << 63 を渡した場合の動作は未定義です。
        '''
        if mod == 0:
            raise ValueError('pow() 3rd argument cannot be 0')
        if mod == 1 or mod == -1:
            return 0
        if exp < 0:  #拡張ユークリッドの互除法
            a, b, x, y = base, mod, 1, 0
            while b:
                q = a // b
                a, b, x, y = b, a - q * b, y, x - q * y
            if a != 1 and a != -1:
                raise ValueError('base is not invertible for the given modulus')
            b128, m128 = Int[128](x), Int[128](mod)
            if a == -1:
                b128 = - b128
            exp = - exp
        else:
            b128, m128 = Int[128](base), Int[128](mod)
        v128 = Int[128](1)
        while exp:
            if exp & 1 == 1:
                v128 = v128 * b128 % m128
            b128 = b128 * b128 % m128
            exp >>= 1
        v = int(v128)
        return v + mod if v != 0 and (0 < v) != (0 < mod) else v
    return _codon_pow
pow = _extended_pow()

もう少しだけ粘りたい場合、以下の点をチェックしてみてもいいですが、無理なときは無理なのでほどほどのところで切り上げてください。

  • sys.stdin.readlinesys.setrecursionlimitといった、codon非対応のものは削除してください。
  • match / case文や:=(セイウチ演算子)はバグの原因なので差し替えてください。
  • オーバーフローしていそうな場合、intの代わりにi128を使うと改善するかもしれません。ただし、計算量保証はなくなります。
  • setやdictにタプルを入れたい場合は、タプルの代わりにリストのまま入れてください。
  • codonのsetやdictはハッシュの衝突が起きやすいです。TLEが取れない場合、補足の章にあるintハッシュ変更の拡張機能を追加してください。

補足: テンプレについて

codonは@extend@overloadで多くの機能を変更できます。
筆者の改変例をいくつか提示します。使えそうなものだけ持って行ってください。

map入力受取

map(int, input().split())で入力を受け取れるようにします。
本来は手動でlist(map(int, input().split())) に差し替えるのがベストですが、少し面倒なので2通りの対応例を提示します。
どちらか片方を選んでご利用ください。

対応例1. generatorを改変
#generatorに__getitem__, __contains__を定義
@extend
class Generator:
    def __getitem__(self: Generator[T], _: int) -> T:
        if self.done(): raise StopIteration()
        return self.next()
    def __getitem__(self: Generator[T], key: slice) -> list[T]:
        assert key.stop is None, (
            '''本拡張では、末尾以外でのアンパッキングはできません。
            例として、 N, *A = generator には対応しますが *A, N = generator は非対応です。
            右辺の generator を list(generator) と書き換えてみてください。''')
        return list(self)
    def __contains__(self: Generator[T], key: T) -> bool:
        return key in list(self)
対応例2. mapの返り値をlist(map)で固定
#mapの返り値をlist(map)で固定
import internal.static as _internal_static
def map(f, *args) -> list:
    if _internal_static.len(args) == 0:
        compile_error("map() expects at least one iterator")
    elif _internal_static.len(args) == 1:
        return [f(a) for a in args[0]]
    else:
        return [f(*a) for a in zip(*args)]

整数 除算方向の変更

Pythonと同様の負の無限大丸めに変更します。
intInt(符号つき任意倍長整数)の対応例を示します。

intの除算方向の変更
#int同士の除算結果をPythonの負の無限大丸めに合わせる
@extend
class int:
    @pure
    @llvm
    def _floordiv_int_int(self: int, other: int) -> int:
        %0 = sdiv i64 %self, %other
        ret i64 %0
    @overload
    def __floordiv__(self, other: int):
        d = self._floordiv_int_int(other)
        m = self - d * other
        if m and ((other ^ m) < 0):
            d -= 1
        return d
    @pure
    @llvm
    def _mod_int_int(self: int, other: int) -> int:
        %0 = srem i64 %self, %other
        ret i64 %0
    @overload
    def __mod__(self, other: int) -> int:
        m = self._mod_int_int(other)
        if m and ((other ^ m) < 0):
            m += other
        return m
Int(任意倍長) の除算方向の変更
#Int[N](N <= 128)同士の除算結果をPythonの負の無限大丸めに合わせる
@extend
class Int:
    def __floordiv__(self, other: Int[N]) -> Int[N]:
        if N > 128:
            compile_error("division is not supported on Int[N] when N > 128")
        d = self._floordiv(other)
        m = self - d * other
        if m and ((other ^ m) < Int[N](0)):
            d -= Int[N](1)
        return d
    def __mod__(self, other: Int[N]) -> Int[N]:
        if N > 128:
            compile_error("modulus is not supported on Int[N] when N > 128")
        m = self._mod(other)
        if m and ((other ^ m) < Int[N](0)):
            m += other
        return m

bit_count, bit_length

intクラスには__ctlz__, __cttz__, __ctpop__の命令が追加されているので、これを利用して実装します。
なお、PythonやPyPyのbit_count・bit_lengthと異なり非常に高速です。

bit_count, bit_length
#int.bit_length, int.bit_countに対応
@extend
class int:
    def bit_length(self): 
        return 64 - abs(self).__ctlz__()
    def bit_count(self):
        return abs(self).__ctpop__()

int ハッシュ値変更

(2025/11/25 追記)
調査により、codonのset・dictはハッシュの衝突に脆弱だと判明しました。
原因の詳細は省きますが特に2冪の入力に弱く、例として$2^k$の倍数の入力を行うだけでハッシュが完全に衝突してしまいます。
なのでハッシュの衝突が目立つ場合はハッシュ値のアルゴリズムを変更して対応しましょう。以下はSplitMix64を用いた実装例です。

intハッシュ値変更
#int hash値をSplitMix64で変更する
#Reference: https://prng.di.unimi.it/splitmix64.c
@extend
class int:
    def __hash__(self) -> int:
        z: UInt[64] = UInt[64](self) + UInt[64](0x9e3779b97f4a7c15)
        z = (z ^ (z >> UInt[64](30))) * UInt[64](0xbf58476d1ce4e5b9)
        z = (z ^ (z >> UInt[64](27))) * UInt[64](0x94d049bb133111eb)
        return int(z ^ (z >> UInt[64](31)))

float 出力桁数増加

print(float)での出力桁数はf stringで変更できます。
なお、float128の出力桁数変更は現在非対応です。

float format
#floatの出力桁数を15桁に増やす
@extend
class float:
    def __str__(self): return f'{self:15f}'

pow オーバーフロー回避・逆元計算追加

codonのpowはオーバーフローするうえ、pow(base, -1, mod)の逆元計算にも非対応でやや不便です。
早速改造しましょう。

この実装例では、2引数ならビルトインpowを呼び出し、3引数ならオリジナルpowを呼び出すように工夫しています。
また、繰り返し二乗法で$O(log exp)$回の128bit除算を行います。手元でテストする限りでは十分高速でしたが、この点ご留意ください。

pow
#巨大mod時のオーバーフローを回避  pow(base, -1, mod)に対応
def _extended_pow():
    _builtin_pow = pow
    def _codon_pow(base, exp):
        return _builtin_pow(base, exp)
    @overload
    def _codon_pow(base: int, exp: int, mod: int) -> int:
        '''
        codon用に pow(base, exp, mod) を拡張した関数です。
        1. (abs(mod) - 1) ** 2 >= 1 << 63 の場合に発生していたオーバーフローを回避しました。
        2. pow(base: int, exp: 負整数, mod: int) による逆元計算に対応しました。

        返り値の符号は mod の符号に一致します。
        いずれかの引数に INF_MIN := -1 << 63 を渡した場合の動作は未定義です。
        '''
        if mod == 0:
            raise ValueError('pow() 3rd argument cannot be 0')
        if mod == 1 or mod == -1:
            return 0
        if exp < 0:  #拡張ユークリッドの互除法
            a, b, x, y = base, mod, 1, 0
            while b:
                q = a // b
                a, b, x, y = b, a - q * b, y, x - q * y
            if a != 1 and a != -1:
                raise ValueError('base is not invertible for the given modulus')
            b128, m128 = Int[128](x), Int[128](mod)
            if a == -1:
                b128 = - b128
            exp = - exp
        else:
            b128, m128 = Int[128](base), Int[128](mod)
        v128 = Int[128](1)
        while exp:
            if exp & 1 == 1:
                v128 = v128 * b128 % m128
            b128 = b128 * b128 % m128
            exp >>= 1
        v = int(v128)
        return v + mod if v != 0 and (0 < v) != (0 < mod) else v
    return _codon_pow
pow = _extended_pow()

補足: ライブラリについて

PyPy3のライブラリに型を追加すれば大抵は動くようになりますが、一部移植が難しいものも存在します。
ここでは移植の参考として、筆者の実装例を提示します。

注意点

  • PyPy3の自作ライブラリの移植です。説明のために関数名だけはACL風に寄せてみましたが、内部実装は全く異なります。
  • ACLにない機能も含まれます。
  • 遅いものも含まれます。特に、SortedSet・SortedListはものすごく遅いです。
  • 最低限のランダムテストしか行っていません。バグはご容赦ください。
  • これからACLを移植する方は、PythonからではなくC++から移植した方が速度が出ると思います。あとgithubとかを使った方がいいです

参考文献

UnionFind

for codon・PyPy3

Fenwick Tree

for codon・PyPy3

for codon

Segment Tree

for codon

Lazy Segment Tree

実装はアルゴリズム実技検定 公式テキスト[上級]~[エキスパート]編に影響を受けています。
合成関数の方向や木内二分探索の定義が間違っているかもしれません。

for codon

disjoint Sparse Table

disjoint Sparse Table for codon
disjoint Sparse Table for codon
#disjoint Sparse Table for codon
class disjointSparseTable[Te, Tf]:
    '''
    disjoint Sparse Table for codon
    Θ(NlogN)の前計算の上で、O(1)で区間積を計算します。

    A: 読み込ませる配列
    identity_e: 単位元 要素の型はAと同じにしてください
    node_f: 合成関数 f(node_Lt: Te, node_Rt: Te) -> node_new: Te
    '''
    N: int
    _e: Te
    _f: Tf
    _node: list[Te]
    __slots__ = ('N', '_e', '_f', '_node')
    def __init__(self, A: list[Te], identity_e: Te, node_f: Tf) -> None:
        self.N = N = len(A)
        logN: int = max(1, len(bin(N - 1)) - 2)  #(N - 1).bit_length()
        self._e, self._f = identity_e, node_f
        self._node = node = [self._e for _ in range(N * logN)]
        for h in range(logN):
            offset: int = h * N
            for i, Ai in enumerate(A, start = offset):
                node[i] = Ai
            b = diff = 1 << h
            step: int = 2 << h
            while b < N:
                node[b + offset] = back = A[b]
                i: int = b + 1
                Rt: int = min(b + diff, N)
                while i < Rt:
                    node[i + offset] = back = self._f(back, A[i])
                    i += 1
                b += step
            b: int = diff - 1
            while b < N:
                node[b + offset] = back = A[b]
                i: int = b - 1
                Lt: int = b - diff
                while Lt < i:
                    node[i + offset] = back = self._f(A[i], back)
                    i -= 1
                b += step
    def fold(self, Lt: int, Rt: int) -> Te:
        '半開区間積A[Lt, Rt)を取得します。Lt == Rtの場合、単位元eを返します。'
        assert 0 <= Lt <= Rt <= self.N
        if Lt == Rt:
            return self._e
        Rt -= 1
        if Lt == Rt:
            return self._node[Lt]
        h: int = 63 - (Lt ^ Rt).__ctlz__()  #h ← (Lt ^ Rt).bit_length() - 1
        return self._f( self._node[h * self.N + Lt], self._node[h * self.N + Rt] )

SCC

for codon・PyPy3

最大流

for codon

最小費用流

本実装はPyPy3の実装を流用しており、ダイクストラ法にセグメント木を用いています。
ですがcodonはheapqが十分に高速なので、ライブラリ整備の際はheapqでの実装をおすすめします。

for codon

二部グラフマッチング

for codon

suffix array, Z algorithm

内部実装は特に荒れています。

for codon, PyPy3

畳み込み

for codon, PyPy3

math

floor_sumは未完成で、特にオーバーフローの挙動が不安定です。

isqrt, inv_mod, gcd, ext_gcd, CRT for codon
math for codon
#floor(√n)
def isqrt(n: int) -> int:
    'floor(√n): m ** 2 <= n < (m + 1) ** 2 を満たす非負整数mを求めます。'
    assert n >= 0
    if n >= 3037000499 ** 2:  #floor( √(2 ** 63 - 1) ) = 3037000499
       return 3037000499
    m: int = max(0, int(float(n).sqrt()))  #int(n ** 0.5)
    m2: int = m * m
    while m2 < n:
        m2 += m << 1 | 1
        m += 1
    while m2 > n:
        m -= 1
        m2 -= m << 1 | 1
    return m

#pow(base, -1, mod)
def inv_mod(base: int, mod: int) -> int:
    'pow(base, -1, mod) を求めます。返り値の符号はmodの符号と一致します。'
    assert mod != 0, f'mod must not be zero. {mod = }'
    if mod == 1 or mod == -1:
        return 0
    a, b, x, y = base, mod, 1, 0
    while b:
        q = a // b
        a, b, x, y = b, a - q * b, y, x - q * y
    if a != 1 and a != -1:
        raise ValueError('base is not invertible for the given modulus')
    if a == -1:
        x = - x
    return x + mod if (x ^ mod) < 0 else x

#最大公約数
gcd = lambda x, y: gcd(y, x % y) if y else abs(x)

#拡張ユークリッドの互除法
def ext_gcd(a: int, b: int) -> tuple[int, int, int]:
    '''
    g == a * x + b * y を満たす(g, x, y)を返します。
    a == b == 0 の場合、(g, x, y) = (0, 1, 0) とします。
    そうでない場合、(g, x, y)は以下の条件を満たします。
    g = gcd(a, b) > 0
    abs(x) <= max(1, abs(b // g))
    abs(y) <= max(1, abs(a // g))
    '''
    if b == 0:
        return (a, 1, 0) if a >= 0 else (- a, - 1, 0)
    g, x, y = ext_gcd(b, a % b)
    return g, y, x - (a // b) * y

#中国剰余定理
def CRT(R: list[int], M: list[int]) -> tuple[int, int]:
    '''
    n ≡ R[i] mod M[i] をすべて満たす非負整数n < lcm(M)を求め、(n, lcm(M))を返します。
    答えがない場合は(0, 0)を、len(R) == len(M) == 0の場合は(0, 1)を返します。
    制約: len(R) == len(M), 0 < M[i], lcm(M) < 2 ** 63
    '''
    assert len(R) == len(M)
    assert all(0 < Mi for Mi in M)
    R1, M1 = 0, 1
    for R2, M2 in zip(R, M):
        R2 %= M2
        if R2 < 0:
            R2 += M2
        if M1 > M2:
            R1, M1, R2, M2 = R2, M2, R1, M1
        f, g, i, j = M1, M2, 1, 0  #g: gcd(M1, M2),  i: invmod(M1 // g, M2 // g)
        while f:
            h = g // f
            f, g, i, j = g - h * f, f, j, i - h * j
        p, q = R1 - R2, M1 // g
        r, s = p // g, p % g
        if s:
            return (0, 0)
        R1, M1 = r * i % q * M2 + R2, M2 * q  #assert abs(r * i) < M2 * q
        if R1 < 0:
            R1 += M1
    return (R1, M1)

#floor sum for codon
#Reference: https://qiita.com/AkariLuminous/items/3e2c80baa6d5e6f3abe9
def floor_sum[T](n: T, m: T, a: T, b: T) -> T:
    '''
    sum( floor( (ai + b) / m ) for i in range(n) ) をO(log m)で求めます。
    制約: 0 < m, 型は整数, ai + bがオーバーフローしない
    '''
    zero: T = n ^ n
    one: T = - ~ zero
    assert zero < m, f'mが正整数ではありません。{m = }'
    if n <= zero:
        return zero
    ans: T = zero
    while True:
        if not zero <= a < m:
            a_div, a = divmod(a, m)
            ans += ( ((n - one) >> one) * n if n & one else (n - one) * (n >> one) ) * a_div
        if not zero <= b < m:
            b_div, b = divmod(b, m)
            ans += n * b_div
        y_max: T = a * n + b
        if y_max < m:
            return ans
        else:
            y_div, y_mod = divmod(y_max, m)
            n, m, a, b = y_div, a, m, y_mod

素因数分解

高速素因数分解 for codon
prime for codon
#高速素因数分解 for codon
#Reference: https://qiita.com/t_fuki/items/7cd50de54d3c5d063b4a
class prime:
    #内部関数
    def _miller_rabin(N: int) -> bool:
        if N < 2 or N & 1 == 0:
            return N == 2
        M, e = N - 1, (N - 1).__cttz__()  #e = (M & - M).bit_length() - 1
        d = M >> e  #M = N - 1 = d << e
        N128, M128 = UInt[128](N), UInt[128](M)
        for a in ([2, 7, 61] if N < 48781 * 97561 else
                  [2, 325, 9375, 28178, 450775, 9780504, 1795265022]):
            if a >= N:
                continue
            c = d
            x128, y128 = UInt[128](1), UInt[128](a)  #x = pow(a, d, N)
            while c:  #x = pow(a, d, N)
                if c & 1:
                    x128 = x128 * y128 % N128
                y128 = y128 * y128 % N128
                c >>= 1
            if x128 == UInt[128](1):  #x = pow(a, d, N) ≡ 1 ならおそらく素数
                continue
            while x128 != M128:  #pow(x, 2 ** (c := e未満), N) ≡ -1 ならおそらく素数
                x128 = x128 * x128 % N128
                c += 1
                if x128 == UInt[128](1) or c == e:
                    return False
        return True
    def _pollard_rho(N: int) -> int:  #Nの素因数を探索  ミラーラビンを参照する
        assert N > 0
        if N & 1 == 0:
            return 2
        if N == 1 or prime._miller_rabin(N):
            return N
        while True:
            N128 = Int[128](N)
            step = int(N ** 0.125) + 1
            for c in range(1, N):
                #f(n) = n ** 2 + c mod N と疑似乱数を定義する
                #y128 = f^{s}(0), z128: Π(x128 - y128) mod N128
                #g: gcd(x, y)  t: sの次の目標となる2冪
                y128, z128, c128 = Int[128](0), Int[128](1), Int[128](c)
                g, s, t = 1, 0, 1
                while g == 1:
                    x128 = y128
                    nxt_s = (3 * t) >> 2
                    for _ in range(nxt_s - s):
                        y128 = (y128 * y128 + c128) % N128  #y ← f(y)
                    s = nxt_s
                    while s < t and g == 1:
                        backtrack128 = y128
                        for _ in range(min(step, t - s)):  #N ** 1/8回まとめてgcdを計算
                            y128 = (y128 * y128 + c128) % N128  #y ← f(y)
                            z128 = z128 * (x128 - y128) % N128
                        g, h = N, abs(int(z128))
                        while h:  #g ← gcd(N, z128)
                            g, h = h, g % h
                        s += step
                    s, t = t, t << 1
                if g == N:
                    g, y128 = 1, backtrack128
                    while g == 1:
                        y128 = (y128 * y128 + c128) % N128  #y ← f(y)
                        g, h = abs(int(x128 - y128)), N
                        while h:  #g ← gcd(N, x128 - y128)
                            g, h = h, g % h
                    if g == N:
                        continue  #検出失敗
                if prime._miller_rabin(g):
                    return g
                elif prime._miller_rabin(N // g):
                    return N // g
                else:
                    N = g
                    break  #while Trueへ
    def _fast_fact(N: int) -> list[tuple[int, int]]:
        assert N >= 1
        ans: list[tuple[int, int]] = []
        if N & 1 == 0:
            ans.append((2, N.__cttz__()))
            N >>= N.__cttz__()
        p2 = 1
        for p in range(3, int(N ** 0.25), 2):  #O(N ** 1/4)回のためし割り
            p2 += (p - 1) << 2  #assert p * p == p2
            if p2 > N:
                if N > 1:
                    ans.append((N, 1))
                    N = 1
                break
            if N % p == 0:
                e = 0
                while N % p == 0:
                    N //= p
                    e += 1
                ans.append((p, e))
        while N > 1:
            p = prime._pollard_rho(N)
            e = 0
            while N % p == 0:
                N //= p
                e += 1
            ans.append((p, e))
        ans.sort()
        return ans
    def _enumerate_divisor(N: int) -> list[int]:
        F: list[tuple[int, int]] = prime._fast_fact(N)
        Rt: int = 1
        for _, e in F:
            Rt *= e + 1
        D: list[int] = [1] * Rt
        Rt: int = 1
        for p, e in F:
            for Lt in range(Rt * e):
                D[Rt] = D[Lt] * p
                Rt += 1
        D.sort()
        return D

    #素数判定
    def is_prime(N: int) -> bool:
        '''
        ミラーラビン素数判定法により素数判定を行います。
        計算量: int128の剰余演算の計算量をO(L)としたとき、O(7L * logN)
        制約: 1 <= N < 2 ** 63
        '''
        assert 1 <= N
        return prime._miller_rabin(N)

    #O(N ** 1/4) 高速素因数分解
    def factorize(N: int) -> list[tuple[int, int]]:
        '''
        Nを素因数分解し、(素因数, 次数) の形のリストとして返します。
        期待計算量: int128の剰余演算をO(L)としたとき、O(L * N ** 1/4)
        制約: 1 <= N < 2 ** 63
        '''
        assert 1 <= N
        return prime._fast_fact(N)

    #約数列挙
    def divisor(N: int) -> list[int]:
        '''
        Nの約数を列挙し、ソートして返します。
        期待計算量: 約数の個数をdとしたとき、prime.factorize + O(d * logd)
        制約: 1 <= N < 2 ** 63
        '''
        assert 1 <= N
        return prime._enumerate_divisor(N)

SortedSet・SortedList

tatyamさんの実装のほうが優秀なので、そちらを利用してください。

SkipListを用いた期待$O(logN)$の実装をしてみたのですが、遅すぎてだめでした
内部実装もぐちゃくちゃなので手の施しようがありません。供養として置いておきます

SortedSet for codon

SortedList for codon

Wavelet Matrix

Wavelet Matrix for codon・PyPy3
Wavelet Matrix for codon・PyPy3
#Wavelet Matrix for codon, PyPy3
import heapq as _WM_heapq
class WaveletMatrix:
    '''
    Wavelet Matrix for codon, PyPy3
    非負整数列Aに対する検索を対数時間で行います。
    bit_countの実装上、すべての操作に O(log wordsize) = O(3) の項がかかります。
    
    N = len(A), M = max(A) として
    構築: 時間 O(N logM), 空間O(N)
    検索: 時間 O(logM) ~

    A: 読み込ませたい非負整数列
    '''
    _N: int
    _logM: int
    _size: int
    _C: list[int]
    _D: list[int]
    _zero: list[int]
    _one: list[int]
    _stack: list[int]
    __slots__ = ('_N', '_logM', '_size', '_C', '_D', '_zero', '_one', '_stack')
    def __init__(self, A: list[int]) -> None:
        assert len(A) == 0 or min(A) >= 0, f'Aに負の要素が含まれます。{min(A) = }'
        assert len(A) < 2 ** 29, f'len(A)が長すぎます。 {len(A) = }'
        self._N = N = len(A)
        maxA: int = 0 if len(A) == 0 else max(A)
        self._logM = logM = 0 if maxA == 0 else len(bin(maxA)) - 2
        self._size = size = -(- N >> 5)
        self._C = C = [0] * size * logM  #FIDをlogM個作成
        self._zero: list[int] = [0] * logM
        self._one: list[int] = [0] * logM
        self._stack: list[int] = [0] * ((logM + 1) << 1)
        D: list[int] = list(range(N))
        E: list[int] = [0] * N  #DとEをswapしながら上の桁から決定
        for k in range(logM - 1, -1, -1):
            offset: int = size * k
            zero = one = now = 0
            for b in range(offset, offset + size):
                C[b] = one << 32
                for c in range(32):
                    if now >= N:
                        break
                    v: int = A[D[now]] >> k & 1
                    if v == 0:
                        zero += 1
                    else:
                        one += 1
                        C[b] |= 1 << c
                    now += 1
            Lt, Rt = 0, zero
            for D_now in D:
                if A[D_now] >> k & 1 == 0:
                    E[Lt] = D_now
                    Lt += 1
                else:
                    E[Rt] = D_now
                    Rt += 1
            self._zero[k], self._one[k] = zero, one
            D, E = E, D
        self._D: list[int] = D
    #内部関数: FID
    def _FID_access(self, k: int, i: int) -> int:  #FID[k]に対し、B[i] >> k & 1
        return self._C[self._size * k + (i >> 5)] >> (i & 31) & 1
    def _FID_rank(self, k: int, i: int, num: int) -> int:  #[0, i]のnumの個数
        if i < 0:
            return 0
        Ci: int = self._C[self._size * k + (i >> 5)]
        n: int = Ci & ~(-1 << ((i & 31) + 1))  #one = (Ci >> 32) + n.bit_count()
        n: int = ( n & 0x55555555 ) + ( (n >> 1) & 0x55555555 )
        n: int = ( n & 0x33333333 ) + ( (n >> 2) & 0x33333333 )
        n: int = ( n & 0x0F0F0F0F ) + ( (n >> 4) & 0x0F0F0F0F )
        one: int = (Ci >> 32) + (n * 0x1010101 >> 24 & 63)
        return one if num == 1 else 1 + i - one
    def _FID_stable_sort(self, k: int, i: int) -> int:  #安定ソート後のiの位置
        num: int = self._C[self._size * k + (i >> 5)] >> (i & 31) & 1  #access(i)
        if num == 0:
            return self._FID_rank(k, i - 1, 0)
        else:
            return self._FID_rank(k, i - 1, 1) + self._zero[k]
    def _FID_range_sort(self, k: int, Lt: int, Rt: int, num: int) -> tuple[int, int]:
        offset: int = 0 if num == 0 else self._zero[k]
        return (offset + self._FID_rank(k, Lt - 1, num),
                offset + self._FID_rank(k, Rt - 1, num))

    #基本機能: 計算量がO(logM)
    def access(self, i: int) -> int:
        'A[i]をO(logM)で取得します。'
        if i < 0:
            i += self._N
        assert 0 <= i < self._N
        ans: int = 0
        for k in range(self._logM - 1, -1, -1):
            b: int = self._FID_access(k, i)
            ans |= b << k
            i: int = self._FID_stable_sort(k, i)
        return ans
    def rank(self, Lt: int, Rt: int, value: int) -> int:
        'A[Lt, Rt)のvalueの出現回数をO(logM)で取得します。'
        assert 0 <= Lt <= Rt <= self._N
        if value < 0 or value >> self._logM >= 1:
            return 0
        for k in range(self._logM - 1, -1, -1):
            Lt, Rt = self._FID_range_sort(k, Lt, Rt, value >> k & 1)
        return Rt - Lt
    def select(self, cnt: int, value: int) -> int:
        '''
        0-indexedでcnt個目のvalueの添字をO(logM)で取得します。
        特に、cnt = 0 かつ value in A の時は A[Lt: Rt].index(value) と返り値が一致します。
        cnt >= A.count(value) の場合、Nを返します。
        '''
        if value < 0 or value >> self._logM >= 1:
            return self._N
        Lt, Rt = 0, self._N
        for k in range(self._logM - 1, -1, -1):
            Lt, Rt = self._FID_range_sort(k, Lt, Rt, value >> k & 1)
        if cnt >= Rt - Lt:
            return self._N
        else:
            return self._D[Lt + cnt]
    def kth_min(self, Lt: int, Rt: int, k: int) -> int:
        'sorted( A[Lt, Rt) )[k] : A[Lt, Rt)の小さい側からk番目の要素 をO(logM)で取得します。'
        assert 0 <= Lt <= Rt <= self._N
        assert 0 <= k < Rt - Lt, f'k is out of range: {Rt - Lt = }, {k = }'
        cnt: int = k  #内部的に添字をk → cntに変更
        ans: int = 0
        for k in range(self._logM - 1, -1, -1):
            Li: int = self._FID_rank(k, Lt - 1, 0)
            Ri: int = self._FID_rank(k, Rt - 1, 0)
            zero: int = Ri - Li
            if cnt < zero:
                Lt, Rt = Li, Ri
            else:  #Lt, Rt = self._FID_range_sort(k, Lt, Rt, 1)
                ans |= 1 << k
                cnt -= zero
                offset: int = self._zero[k]
                Lt, Rt = offset + (Lt - Li), offset + (Rt - Ri)
        return ans
    def range_freq(self, Lt: int, Rt: int, vL: int, vR: int) -> int:
        'A[Lt, Rt)に存在する、 vL <= Ai < vR を満たすAiの個数をO(logM)で取得します。'
        assert 0 <= Lt <= Rt <= self._N
        if not vL < vR:
            return 0
        ans: int = Rt - Lt
        stack: list[int] = self._stack
        if vL > 0:
            stack[0], stack[1] = 0, self._logM << 58 | Lt << 29 | Rt
            d: int = 2
            while d:
                c, x = stack[d - 2], stack[d - 1]
                d -= 2
                k, Li, Ri = x >> 58, x >> 29 & 0x1FFFFFFF, x & 0x1FFFFFFF
                if c + (1 << k) <= vL:
                    ans -= Ri - Li
                    continue
                k -= 1
                if k == -1:
                    break
                for b in (1, 0):
                    Lj, Rj = self._FID_range_sort(k, Li, Ri, b)
                    if Lj != Rj:
                        stack[d], stack[d + 1] = c | b << k, k << 58 | Lj << 29 | Rj
                        d += 2
        if vR <= ~(-1 << self._logM):
            stack[0], stack[1] = 0, self._logM << 58 | Lt << 29 | Rt
            d: int = 2
            while d:
                c, x = stack[d - 2], stack[d - 1]
                d -= 2
                k, Li, Ri = x >> 58, x >> 29 & 0x1FFFFFFF, x & 0x1FFFFFFF
                if c >= vR:  #変更点
                    ans -= Ri - Li
                    continue
                k -= 1
                if k == -1:
                    break
                for b in (0, 1):  #変更点
                    Lj, Rj = self._FID_range_sort(k, Li, Ri, b)
                    if Lj != Rj:
                        stack[d], stack[d + 1] = c | b << k, k << 58 | Lj << 29 | Rj
                        d += 2   
        return ans
    def prev_value(self, Lt: int, Rt: int, value: int) -> int:
        '''
        A[Lt, Rt)のうち、valueより真に小さい最大値をO(logM)で取得します。
        そのような値が存在しない場合、-1を返します。
        '''
        assert 0 <= Lt <= Rt <= self._N
        if Lt == Rt or value <= 0:
            return -1
        stack: list[int] = self._stack
        stack[0], stack[1] = 0, self._logM << 58 | Lt << 29 | Rt
        d: int = 2
        while d:
            c, x = stack[d - 2], stack[d - 1]
            d -= 2
            k, Li, Ri = x >> 58, x >> 29 & 0x1FFFFFFF, x & 0x1FFFFFFF
            if c >= value:  #変更点
                continue
            k -= 1
            if k == -1:
                return c
            for b in (0, 1):  #変更点
                Lj, Rj = self._FID_range_sort(k, Li, Ri, b)
                if Lj != Rj:
                    stack[d], stack[d + 1] = c | b << k, k << 58 | Lj << 29 | Rj
                    d += 2
        else:
            return -1
    def next_value(self, Lt: int, Rt: int, value: int) -> int:
        '''
        A[Lt, Rt)のうち、valueより真に大きい最小値をO(logM)で取得します。
        そのような値が存在しない場合、-1を返します。
        '''
        assert 0 <= Lt <= Rt <= self._N
        if Lt == Rt or value >= ~(-1 << self._logM):
            return -1
        value += 1
        stack: list[int] = self._stack
        stack[0], stack[1] = 0, self._logM << 58 | Lt << 29 | Rt
        d: int = 2
        while d:
            c, x = stack[d - 2], stack[d - 1]
            d -= 2
            k, Li, Ri = x >> 58, x >> 29 & 0x1FFFFFFF, x & 0x1FFFFFFF
            if c + (1 << k) <= value:
                continue
            k -= 1
            if k == -1:
                return c
            for b in (1, 0):
                Lj, Rj = self._FID_range_sort(k, Li, Ri, b)
                if Lj != Rj:
                    stack[d], stack[d + 1] = c | b << k, k << 58 | Lj << 29 | Rj
                    d += 2
        else:
            return -1
        
    #基本機能: 計算量がO(logM)でないもの
    def topk_mode(self, Lt: int, Rt: int, k: int) -> list[tuple[int, int]]:
        '''
        A[Lt, Rt)の頻度を数え、E: [(値, 個数) のリスト]を作成します。
        その後 個数の降順・同率なら値の昇順 にEをソートし、E[:cnt]を返します。
        計算量は O( cnt * logNlogM ) です。返り値のタプルの順序は(値, 個数)です。
        '''
        assert 0 <= Lt <= Rt <= self._N
        ans: list[tuple[int, int]] = []
        cnt: int = k  #内部的に添字をk → cntに変更
        if Lt == Rt or cnt <= 0:
            return ans
        Q: list[tuple[int, int, int]] = [
            ( ~( (Rt - Lt) << 32 | self._logM), 0, Lt << 32 | Rt )]
        while Q and len(ans) < cnt:
            x, y, z = _WM_heapq.heappop(Q)
            w, k = (~ x) >> 32, ((~ x) & 0xFFFFFFFF) - 1
            Li, Ri = z >> 32, z & 0xFFFFFFFF
            assert Li <= Ri and 0 <= w == Ri - Li and k >= -1
            if k == -1:
                ans.append((y, Ri - Li))
                continue
            for b in range(2):
                Lj, Rj = self._FID_range_sort(k, Li, Ri, b)
                if Rj > Lj:
                    _WM_heapq.heappush(
                        Q, (~ ((Rj - Lj) << 32 | k), y | b << k, Lj << 32 | Rj ))
        return ans
    def intersect(self, L1: int, R1: int, L2: int, R2: int) -> list[tuple[int, int]]:
        '''
        A[L1, R1) と A[L2, R2) の共通要素を取り出し、(値, 個数) の昇順で返します。
        計算量は O( (R - L)logM )です。
        '''
        assert 0 <= L1 <= R1 <= self._N and 0 <= L2 <= R2 <= self._N
        ans: list[tuple[int, int]] = []
        if L1 == R1 or L2 == R2:
            return ans
        stack: list[int] = self._stack
        while len(stack) < (self._logM + 1) * 3:
            stack.append(0)
        stack[0], stack[1], stack[2] = 0, self._logM << 58 | L1 << 29 | R1, L2 << 29 | R2
        d: int = 3
        while d:
            c, y, z = stack[d - 3], stack[d - 2], stack[d - 1]
            d -= 3
            k, L1i, R1i = (y >> 58) - 1, y >> 29 & 0x1FFFFFFF, y & 0x1FFFFFFF
            L2i, R2i = z >> 29 & 0x1FFFFFFF, z & 0x1FFFFFFF
            if k == -1:
                ans.append((c, min(R1i - L1i, R2i - L2i)))
                continue
            for b in (1, 0):
                L1j, R1j = self._FID_range_sort(k, L1i, R1i, b)
                L2j, R2j = self._FID_range_sort(k, L2i, R2i, b)
                if L1j != R1j and L2j != R2j:
                    stack[d] = c | b << k
                    stack[d + 1] = k << 58 | L1j << 29 | R1j
                    stack[d + 2] = L2j << 29 | R2j
                    d += 3
        return ans

行列累乗

行列累乗 for codon, PyPy3
行列累乗
#行列累乗 for codon, PyPy3
class matrix_power:
    '''
    行列累乗 for codon, PyPy3
    法MODの下、行列計算を行います。
    行列は2次元リストとして渡してください。
    返り値は新しい2次元リストで、全成分が0以上MOD未満を満たします。

    MOD: 法
    '''
    MOD: int
    _acc_limit: int
    __slots__ = ('MOD', '_acc_limit')
    def __init__(self, MOD: int) -> None:
        self.MOD = MOD
        if MOD > 3037000500:  #3_037_ + 3_037 ** 2 >= 2 ** 63
            self._acc_limit: int = 1
        else:  #(MOD - 1) + _acc_limit * (MOD - 1) ** 2 < 2 ** 63
            self._acc_limit: int = (~(-1 << 63) - (MOD - 1)) // ((MOD - 1) ** 2)
    #内部関数
    def _matrix_add(self, A: list[list[int]], B: list[list[int]],
                    H: int, W: int) -> list[list[int]]:
        C: list[list[int]] = [[0] * W for _ in range(H)]
        for h in range(H):
            Ah, Bh, Ch = A[h], B[h], C[h]
            for w, Ahw in enumerate(Ah):
                Ch[w] = (Ahw % self.MOD) + (Bh[w] % self.MOD)
                if Ch[w] >= self.MOD:
                    Ch[w] -= self.MOD
                while Ch[w] < 0:
                    Ch[w] += self.MOD
        return C
    def _matrix_mul(self, A: list[list[int]], B: list[list[int]], C: list[list[int]],
                    H: int, W: int, X: int) -> list[list[int]]:
        #すべての成分が 0 <= A[h][x], B[x][w] < MOD を満たすことを要求する
        B_w: list[int] = [0] * X
        for w in range(W):  #C[h][w] = sum(A[h][x] * B[x][w] for all x)
            for x in range(X):
                B_w[x] = B[x][w]
            for h in range(H):
                cnt: int = 0
                k: int = self._acc_limit
                for x, Ahx in enumerate(A[h]):
                    cnt += Ahx * B_w[x]
                    k -= 1
                    if k == 0:
                        cnt %= self.MOD
                        k = self._acc_limit
                C[h][w] = cnt % self.MOD
        return C
    def _matrix_doubling_mul(self, A: list[list[int]], N: int, k: int) -> list[list[int]]:
        if k == 0:
            return self.eye(N)
        C: list[list[int]] = [[0] * N for _ in range(N)]
        for h in range(N):
            Ch: list[int] = C[h]
            for w, Ahw in enumerate(A[h]):
                Ch[w] = Ahw % self.MOD
                if Ch[w] < 0:
                    Ch[w] += self.MOD
        if k == 1:
            return C
        D: list[list[int]] = [[0] * N for _ in range(N)]
        E: list[list[int]] = [C[h][:] for h in range(N)]
        for i in range(len(bin(k)) - 4, -1, -1):
            self._matrix_mul(C, C, D, N, N, N)
            C, D = D, C
            if k >> i & 1 == 1:
                self._matrix_mul(C, E, D, N, N, N)
                C, D = D, C
        return C
        
    #基本機能
    def eye(self, N: int) -> list[list[int]]:
        'N行N列の単位行列を返します。'
        A: list[list[int]] = [[0] * N for _ in range(N)]
        for i in range(N):
            A[i][i] = 1
        return A
    def add(self, A: list[list[int]], B: list[list[int]]) -> list[list[int]]:
        '行列C := A + B を新しく生成します。'
        assert len(A) == len(B)
        if len(A) == 0:
            return []
        H: int = len(A)
        W: int = len(A[0])
        assert all(len(Ai) == W for Ai in A) and all(len(Bi) == W for Bi in B)
        return self._matrix_add(A, B, H, W)
    def mul(self, A: list[list[int]], B: list[list[int]]) -> list[list[int]]:
        'H行X列の行列Aと、X行W列の行列Bから、行列C := A * Bを新しく作成します。'
        H: int = len(A)
        if H == 0:
            return []
        X: int = len(A[0])
        assert all(len(Ai) == X for Ai in A) and len(B) == X
        if X == 0:
            return [[] for _ in range(H)]
        W: int = len(B[0])
        assert all(len(Bi) == W for Bi in B)
        new_A: list[list[int]] = [[0] * X for _ in range(H)]
        new_B: list[list[int]] = [[0] * W for _ in range(X)]
        for h in range(H):
            new_Ah: list[int] = new_A[h]
            for x, Ahx in enumerate(A[h]):
                new_Ah[x] = Ahx % self.MOD
                if new_Ah[x] < 0:
                    new_Ah[x] += self.MOD
        for x in range(X):
            new_Bx: list[int] = new_B[x]
            for w, Bxw in enumerate(B[x]):
                new_Bx[w] = Bxw % self.MOD
                if new_Bx[w] < 0:
                    new_Bx[w] += self.MOD
        return self._matrix_mul(new_A, new_B, [[0] * W for _ in range(H)], H, W, X)
    def doubling_mul(self, A: list[list[int]], k: int) -> list[list[int]]:
        '正方行列Aから、正方行列C := A ** k を新しく作成します。'
        N: int = len(A)
        if N == 0:
            return []
        assert all(len(Ai) == N for Ai in A)
        assert k >= 0
        return self._matrix_doubling_mul(A, N, k)

おわりに

おわりです。
codonの開拓が進んだらうれしいです。

4
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?