0
0

More than 1 year has passed since last update.

Pythonで値オブジェクトっぽいModint

Last updated at Posted at 2023-04-19

こだわりポイント

  • Immutableいいよね
  • modが異なるModint同士の演算を防止
  • 制約や規則は極力クラス外に定義、__truediv__だけ諦めた
  • 競プロ目的ではないため速度には無頓着

コード

modint.py
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass
from functools import wraps
from typing import Self, SupportsInt, TypeAlias

from sympy import mod_inverse as _mod_inverse

# NOTE mod 素数 だけでいいなら:
# from functools import lru_cache
# @lru_cache(maxsize=None)
# def _isprime(n: int):
#     if n in {2, 3}:
#         return True
#     if n % 2 == 0 or n < 2:
#         return False
#     return all(n % i != 0 for i in range(3, int(n**0.5) + 1, 2))

# def _mod_inverse(a: int, m: int):
#     if _isprime(m):
#         return pow(a, m - 2, m)
#     raise ValueError(f"{m} is not prime")

_ModintMethod: TypeAlias = Callable[["Modint", "Modint | SupportsInt"], "Modint"]


def _ensure_mod_equals(func: _ModintMethod) -> _ModintMethod:
    """dunder-method で self.mod と other.mod が一致することを保証する"""

    @wraps(func)
    def inner(self: Modint, other: Modint | SupportsInt) -> Modint:
        if isinstance(other, Modint) and self.mod != other.mod:
            msg = f"{self.mod=} != {other.mod=}"
            raise ValueError(msg)
        return func(self, other)

    return inner


@dataclass(frozen=True, slots=True)
class Modint:
    num: int
    mod: int

    # NOTE init, repr, eq は @dataclass が実装するものに任せる
    # NOTE eq と frozen のため、hashable になる

    @property
    def inverse(self) -> int:
        # NOTE `cached_property` では `__slots__` を使えない
        # 内部の関数を `lru_cache` すれば使える
        # `sympy.isprime` は篩を共有しているから `lru_cache` を付けるのはナンセンス
        return _mod_inverse(self.num, self.mod)

    def __post_init__(self) -> None:
        # NOTE `frozen=True` で引数の加工をする場合、`object.__setattr__` 経由でバイパスする
        if not (0 <= self.num < self.mod):
            object.__setattr__(self, "num", self.num % self.mod)

    def __int__(self) -> int:
        return self.num

    def __str__(self) -> str:
        return str(self.num)

    @_ensure_mod_equals
    def __add__(self, other: Self | SupportsInt) -> Self:
        return self.__class__(int(self) + int(other), self.mod)

    @_ensure_mod_equals
    def __sub__(self, other: Self | SupportsInt) -> Self:
        return self.__class__(int(self) - int(other), self.mod)

    @_ensure_mod_equals
    def __mul__(self, other: Self | SupportsInt) -> Self:
        return self.__class__(int(self) * int(other), self.mod)

    @_ensure_mod_equals
    def __truediv__(self, other: Self | SupportsInt) -> Self:
        inv = (
            other.inverse
            if isinstance(other, self.__class__)
            else _mod_inverse(int(other), self.mod)
        )
        return self.__class__(int(self) * inv, self.mod)

    @_ensure_mod_equals
    def __pow__(self, power: Self | SupportsInt) -> Self:
        return self.__class__(pow(int(self), int(power), self.mod), self.mod)

    @_ensure_mod_equals
    def __rsub__(self, other: Self | SupportsInt) -> Self:
        return self.__class__(-int(self) + int(other), self.mod)

    @_ensure_mod_equals
    def __rtruediv__(self, other: Self | SupportsInt) -> Self:
        return self.__class__(self.inverse * int(other), self.mod)

    __radd__ = __add__
    __rmul__ = __mul__
    __floordiv__ = __truediv__
    __rfloordiv__ = __rtruediv__
test_modint.py
from dataclasses import FrozenInstanceError

import pytest
from modint import Modint


def test_normal():
    x = Modint(10, 998244353)
    y = Modint(7, 998244353)

    assert int(x + y) == 17
    assert int(x - y) == 3
    assert int(x * y) == 70
    assert int(x / y) == 570425346
    assert int(x // y) == 570425346
    assert int(y**5) == 16807
    assert x != y

    x += y
    assert int(x) == 17
    x -= y
    assert int(x) == 10
    x *= y
    assert int(x) == 70
    x /= y
    assert int(x) == 10
    x //= y
    assert int(x) == 570425346

    assert int(x + 5) == 570425351
    assert int(22 + y) == 29

    assert x.inverse == 99824436


def test_not_same_mod():
    with pytest.raises(ValueError):
        _ = Modint(1, 7) + Modint(1, 3)


def test_frozen():
    x = Modint(1, 7)

    with pytest.raises(FrozenInstanceError):
        x.num = 42  # type: ignore

    with pytest.raises(FrozenInstanceError):
        x.mod = 42  # type: ignore

参考文献

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0