はじめに
AtCoder Library のModint
が個人的に使いづらいので自作しました。
方針
- 基本的には AtCoder Library とインターフェースを合わせるが、一部で Python 標準の演算子を用いる。
x.val()
はint(x)
で、x.inv()
はx ** -1
で、raw(x)
はModint(x)
で代用する - mod の指定は大域変数で行う。たいてい 1 つしか使わないので、AtCoder Library のように
ModContext
で指定するのは面倒 - (元のC++の) AtCoder Library と同様に、加減乗除の片方が
int
でも動作可能とする - イミュータブルなオブジェクトにする
- 呼び出される回数が多いはずなので、分岐は避ける
コード
MOD = 1000000007
class Modint:
def __init__(self, v = 0):
self._v = v % MOD
def __add__(self, rhs):
return Modint(self._v + int(rhs))
def __sub__(self, rhs):
return Modint(self._v - int(rhs))
def __mul__(self, rhs):
return Modint(self._v * int(rhs))
def __floordiv__(self, rhs):
return Modint(self._v * pow(int(rhs), -1, MOD))
def __pow__(self, n):
return Modint(pow(self._v, n, MOD))
def __radd__(self, rhs):
return Modint(rhs + self._v)
def __rsub__(self, rhs):
return Modint(rhs - self._v)
def __rmul__(self, rhs):
return Modint(rhs * self._v)
def __rfloordiv__(self, rhs):
return Modint(rhs * pow(self._v, -1, MOD))
def __eq__(self, rhs):
return rhs == self._v
def __neg__(self):
return Modint(-self._v)
def __pos__(self):
return self
def __str__(self):
return str(self._v)
def __repr__(self):
return f'Modint({self._v!r})'
def __int__(self):
return self._v
def __hash__(self):
return hash(self._v)
def __bool__(self):
return bool(self._v)
実装について
末尾の参考記事とかぶらない内容のみ説明します。
__iadd__ など
イミュータブルなオブジェクトにしたいなら、実装してはいけません。定義しなくても+=
などは動作しますし、定義してしまうと意図しない値の共有が起こって混乱することになります。
新規のオブジェクトを返すように実装すればイミュータブルにはなりますが、公式の説明に「 These methods should attempt to do the operation in-place (modifying self)」とある通り、実装するなら自身の値を上書きすべきです。
__eq__
__add__
などと異なりint(rhs)
を使用していないのは、無関係のオブジェクトと比較してもエラーにならないようにするためです。
相手がint
の場合に剰余をとらないのは、上記エラー対策に加えて、比較を推移的(公式の説明も参照)にするという意味もあります。もし剰余を取ってから比較する仕様にすると、MOD=11
の場合に2 == Modint(2) == 13
となり、推移的だとすると2 == 13
という誤った結果が導かれてしまいます。
!=
については、__ne__
が定義されていなければ__eq__
の否定として定義されるので、明示的に定義する必要はありません。
__neg__, __pos__
それぞれ、単項マイナス、単項プラスの実装となります。AtCoder Library で実装されていたので追加しましたが、不要なら削除しても構いません。
__repr__
公式の説明に従い「 (適切な環境が与えられれば) 同じ値のオブジェクトを再生成するのに使える、有効な Python 式のようなもの」を返すようにしています。
__hash__
Modint
をdict
のキーにしたいなら実装が必要です。そのようなケースは少ないと思われるので、不要なら削除しても構いません。
__bool__
こちらを定義しないと、Modint(0)
がTrue
扱いとなってしまいます。それで困るケースは少ないと思われるので、不要なら削除しても構いません。
余談(こちらのコードを書いた動機)
競プロ典型90問の008について、
AtCoder Library の Modint を使って以下のようなコードを作成したところ(コードが拙いのは無視してください)、入力例2にて16
という誤った値が出力されました。
from collections import defaultdict
from atcoder.modint import ModContext, raw
from itertools import pairwise
n = int(input())
s = defaultdict(list)
with ModContext(1_000_000_007):
for i, c in enumerate(input()):
s[c].append([i, 0])
for c in s:
s[c] = s[c][::-1]
for v in s['r']:
v[1] = raw(1)
for c, d in pairwise('atcoder'[::-1]):
i, o = 0, raw(0)
for v in s[d]:
while i < len(s[c]) and v[0] <= s[c][i][0]:
o += s[c][i][1]
i += 1
v[1] = o
print(sum((i[1] for i in s['a']), start=raw(0)).val())
調べたところ、17行目のo += s[c][i][1]
で意図せず値が共有されてしまい、関係ない値まで更新されていたことが原因でした。o = o + s[c][i][1]
に変更することで正答となりました。
既にこの世には Python 用の Modint 実装があふれているのは重々承知していますが、 この怒りをぶつけるため ミュータブルになっていたり、実装が複雑だったりで、自分の望むものが見つからなかったため、あえて自作することとしました。
最終的に、先頭のModint
の実装 + 以下のコードで正答となりました。めでたしめでたし。
from collections import defaultdict
from itertools import pairwise
n = int(input())
s = defaultdict(list)
for i, c in enumerate(input()):
s[c].append([i, 0])
for c in s:
s[c] = s[c][::-1]
for v in s['r']:
v[1] = Modint(1)
for c, d in pairwise('atcoder'[::-1]):
i = o = 0
for v in s[d]:
while i < len(s[c]) and v[0] <= s[c][i][0]:
o += s[c][i][1]
i += 1
v[1] = o
print(sum(i[1] for i in s['a']))
参考
Python で Modint を使いたい場合、AtCoder なら公式ライブラリがありますし、それで不足する場合でも Qiita 上にいくらでも記事があるので、気に入ったものを選べばよいかと思います。