はじめに
競技プログラミング界隈で用いられている modint (自動で剰余演算される整数) を Fortran で実装しました。実装は AtCoder の解説放送 (ABC 第 129, 130, 133 回) (リンク先は YouTube の動画なので音量に注意してください) を参考にしました。
Fortran で実装した理由としては、 Fortran での実装例が見当たらなかったのと、最近触り始めた Go 言語で modint を実装しようとした際に、 Go 言語が演算子のオーバーロードをサポートしていない事実に気がついたためです。
環境
macOS Sierra (10.12.6)
GNU Fortran (GCC) 6.3.0
プログラム
実際に実装したものが以下のコードです。通常の integer との互換性を持たせるために少々長くなってしまいましたが、演算子をオーバーロードしているおかげで integer と同じ感覚で演算可能です。
module mod_modint
implicit none
integer(8) :: modulus = 1000000007_8
! integer(8) :: modulus = 998244353_8
! integer(8) :: modulus = 1000000009_8
type modint
private
integer(8) :: num
contains
procedure :: get => getnum
end type
interface change_modulus
module procedure :: change_modulus64, change_modulus32
end interface change_modulus
interface modint
module procedure :: newm, newi
end interface modint
interface assignment(=)
module procedure :: setm, seti64, seti32
end interface assignment(=)
interface operator(+)
module procedure :: posm, addmm, addmi64, addmi32, addi64m, addi32m
end interface operator(+)
interface operator(-)
module procedure :: negm, submm, submi64, submi32, subi64m, subi32m
end interface operator(-)
interface operator(*)
module procedure :: mulmm, mulmi64, mulmi32, muli64m, muli32m
end interface operator(*)
interface operator(/)
module procedure :: divmm, divmi64, divmi32, divi64m, divi32m
end interface operator(/)
interface operator(**)
module procedure :: powmi64, powmi32
end interface operator(**)
interface inv
module procedure :: invm
end interface inv
contains
subroutine change_modulus64(newmod)
integer(8), intent(in) :: newmod
if (newmod == 0_8) then
write(*,'(a)') "Error: Invalid value (newmod == 0). (modint, change_modulus64)"
stop
end if
modulus = newmod
end
subroutine change_modulus32(newmod)
integer, intent(in) :: newmod
if (newmod == 0) then
write(*,'(a)') "Error: Invalid value (newmod == 0). (modint, change_modulus32)"
stop
end if
modulus = int(newmod,8)
end
integer(8) function getnum(this)
class(modint), intent(in) :: this
getnum = this%num
end
pure elemental type(modint) function newm()
newm%num = 0_8
end
pure elemental subroutine setm(x,y)
type(modint), intent(inout) :: x
type(modint), intent(in) :: y
x%num = y%num
end
pure elemental function posm(x) result(n)
type(modint), intent(in) :: x
type(modint) :: n
n%num = x%num
end
pure elemental function negm(x) result(n)
type(modint), intent(in) :: x
type(modint) :: n
n%num = modulus-x%num
end
pure elemental function addmm(x,y) result(n)
type(modint), intent(in) :: x, y
type(modint) :: n
n%num = x%num+y%num
if (n%num >= modulus) n%num = n%num-modulus
end
pure elemental function submm(x,y) result(n)
type(modint), intent(in) :: x, y
type(modint) :: n
n%num = x%num-y%num
if (n%num < 0_8) n%num = n%num+modulus
end
pure elemental function mulmm(x,y) result(n)
type(modint), intent(in) :: x, y
type(modint) :: n
n%num = mod(x%num*y%num,modulus)
end
impure elemental function invm(x) result(n)
type(modint), intent(in) :: x
type(modint) :: n
integer(8) :: a, b, c, q, r, v
a = x%num
if (a == 0_8) then
write(*,'(a)') "Error: Division by zero (x == 0). (modint, invm)"
stop
end if
b = modulus
c = 0_8
v = 1_8
do while (b /= 0_8)
q = a/b
r = mod(a,b)
a = b
b = r
r = c
c = v-c*q
v = r
end do
n%num = mod(v,modulus)
if (n%num < 0_8) n%num = n%num+modulus
end
impure elemental function divmm(x,y) result(n)
type(modint), intent(in) :: x, y
type(modint) :: n
n = mulmm(x,invm(y))
end
!##########################################################################
!##################### overload with (normal) integer #####################
!##########################################################################
impure elemental type(modint) function newi(i)
class(*), intent(in) :: i
select type(i)
type is (integer(8))
newi%num = i
type is (integer)
newi%num = int(i,8)
type is (integer(2))
newi%num = int(i,8)
type is (integer(1))
newi%num = int(i,8)
class default
write(*,'(a)') "Error: Invalid value (i is not integer). (modint, newi)"
stop
end select
newi%num = mod(newi%num,modulus)
if (newi%num < 0_8) newi%num = newi%num+modulus
end
impure elemental subroutine seti64(x,i)
type(modint), intent(inout) :: x
integer(8), intent(in) :: i
call setm(x,newi(i))
end
impure elemental subroutine seti32(x,i)
type(modint), intent(inout) :: x
integer, intent(in) :: i
call setm(x,newi(i))
end
impure elemental function addmi64(x,i) result(n)
type(modint), intent(in) :: x
integer(8), intent(in) :: i
type(modint) :: n
n = addmm(x,newi(i))
end
impure elemental function addmi32(x,i) result(n)
type(modint), intent(in) :: x
integer, intent(in) :: i
type(modint) :: n
n = addmm(x,newi(i))
end
impure elemental function addi64m(i,y) result(n)
integer(8), intent(in) :: i
type(modint), intent(in) :: y
type(modint) :: n
n = addmm(newi(i),y)
end
impure elemental function addi32m(i,y) result(n)
integer, intent(in) :: i
type(modint), intent(in) :: y
type(modint) :: n
n = addmm(newi(i),y)
end
impure elemental function submi64(x,i) result(n)
type(modint), intent(in) :: x
integer(8), intent(in) :: i
type(modint) :: n
n = submm(x,newi(i))
end
impure elemental function submi32(x,i) result(n)
type(modint), intent(in) :: x
integer, intent(in) :: i
type(modint) :: n
n = submm(x,newi(i))
end
impure elemental function subi64m(i,y) result(n)
integer(8), intent(in) :: i
type(modint), intent(in) :: y
type(modint) :: n
n = submm(newi(i),y)
end
impure elemental function subi32m(i,y) result(n)
integer, intent(in) :: i
type(modint), intent(in) :: y
type(modint) :: n
n = submm(newi(i),y)
end
impure elemental function mulmi64(x,i) result(n)
type(modint), intent(in) :: x
integer(8), intent(in) :: i
type(modint) :: n
n = mulmm(x,newi(i))
end
impure elemental function mulmi32(x,i) result(n)
type(modint), intent(in) :: x
integer, intent(in) :: i
type(modint) :: n
n = mulmm(x,newi(i))
end
impure elemental function muli64m(i,y) result(n)
integer(8), intent(in) :: i
type(modint), intent(in) :: y
type(modint) :: n
n = mulmm(newi(i),y)
end
impure elemental function muli32m(i,y) result(n)
integer, intent(in) :: i
type(modint), intent(in) :: y
type(modint) :: n
n = mulmm(newi(i),y)
end
impure elemental function divmi64(x,i) result(n)
type(modint), intent(in) :: x
integer(8), intent(in) :: i
type(modint) :: n
n = divmm(x,newi(i))
end
impure elemental function divmi32(x,i) result(n)
type(modint), intent(in) :: x
integer, intent(in) :: i
type(modint) :: n
n = divmm(x,newi(i))
end
impure elemental function divi64m(i,y) result(n)
integer(8), intent(in) :: i
type(modint), intent(in) :: y
type(modint) :: n
n = divmm(newi(i),y)
end
impure elemental function divi32m(i,y) result(n)
integer, intent(in) :: i
type(modint), intent(in) :: y
type(modint) :: n
n = divmm(newi(i),y)
end
impure elemental function powmi64(x,i) result(n)
type(modint), intent(in) :: x
integer(8), intent(in) :: i
type(modint) :: n, p
integer(8) :: m
n = newi(1_8)
p = x
m = i
if (i < 0_8) then
p = invm(x)
m = abs(i)
end if
do while (m > 0_8)
if (btest(m,0)) n = mulmm(p,n)
p = mulmm(p,p)
m = rshift(m,1)
end do
end
impure elemental function powmi32(x,i) result(n)
type(modint), intent(in) :: x
integer, intent(in) :: i
type(modint) :: n, p
integer :: m
n = newi(1_8)
p = x
m = i
if (i < 0) then
p = invm(x)
m = abs(i)
end if
do while (m > 0)
if (btest(m,0)) n = mulmm(p,n)
p = mulmm(p,p)
m = rshift(m,1)
end do
end
end module mod_modint
動作確認
動作確認のために以下のようなコードを用意しました。
program test_modint
use mod_modint
implicit none
type(modint) :: a, b, c
call change_modulus(13)
a = modint(8)
write(*,'(a,i0)') "a = ", a%get()
b = 22_8
write(*,'(a,i0)') "b = ", b%get()
c = -5
write(*,'(a,i0)') "c = ", c%get()
c = -a
write(*,'(a,i0)') "-a = ", c%get()
c = a+b
write(*,'(a,i0)') "a + b = ", c%get()
c = a-b
write(*,'(a,i0)') "a - b = ", c%get()
c = a*b
write(*,'(a,i0)') "a * b = ", c%get()
c = a/b
write(*,'(a,i0)') "a / b = ", c%get()
c = a**5
write(*,'(a,i0)') "a ** 5 = ", c%get()
c = a**(-5_8)
write(*,'(a,i0)') "a ** -5_8 = ", c%get()
c = a+11
write(*,'(a,i0)') "a + 11 = ", c%get()
c = 17-a
write(*,'(a,i0)') "17 - a = ", c%get()
c = a*9_8
write(*,'(a,i0)') "a * 9_8 = ", c%get()
c = a/4_8
write(*,'(a,i0)') "a / 4_8 = ", c%get()
end program test_modint
実行結果は以下のとおりです。
a = 8
b = 9
c = 8
-a = 5
a + b = 4
a - b = 12
a * b = 7
a / b = 11
a ** 5 = 8
a ** -5_8 = 5
a + 11 = 6
17 - a = 9
a * 9_8 = 7
a / 4_8 = 2
上から順に見ていきます。
法の変更
call change_modulus(13)
modint の法 modulus
は、デフォルトでは 1000000007 としていますが、途中で変更したい場合は change_modulus()
で変更できます。今回の場合、 1000000007 は大きすぎるので、わかりやすいように 13 に変更しています。この後除算を行うので、 13 を法とした場合の逆元を記しておきます。
元の数 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
逆元 | 1 | 7 | 9 | 10 | 8 | 11 | 2 | 5 | 3 | 4 | 6 | 12 |
コンストラクタ・代入
a = modint(8)
b = 22_8
c = -5
a = 8
b = 9
c = 8
a = modint(8)
では、変数 a
にコンストラクタ modint()
で生成された modint を代入しています。
b = 22_8
では、変数 b
に integer(8) である 22_8
を代入しています。本来ならばこのような書き方は、 b
と 22_8
の型が異なるのでできません (以下のコンパイル時のメッセージを参照) が、今回のプログラムでは integer(8)・integer(4) に対する代入演算子のオーバーロードを行なっているのでコンパイルが通ります。
test_modint.f08:8:6:
b = 22_8
1
Error: Can't convert INTEGER(8) to TYPE(modint) at (1)
また、 b
に代入している 22_8
は法である 13 以上の数なので、剰余演算が行われて 9_8
となっています。
c = -5
では、変数 c
に integer(4) である -5
を代入しています。これも b
の場合と同様です。また、 c
に代入している -5
は負の数なので、 0 以上 13 未満になるように剰余計算され、 integer(4) の場合は integer(8) に変換されるので 8_8
となります。
四則演算
c = -a
c = a+b
c = a-b
c = a*b
c = a/b
-a = 5
a + b = 4
a - b = 12
a * b = 7
a / b = 11
c = -a
は、単項演算子 -
の例です。この演算では 0 以下の数になるため、自動的に 0 以上 13 未満になるように modulus
が足されます。
c = a+b
、 c = a-b
、 c = a*b
はそれぞれ加算、減算、乗算の例です。これらも 0 以上 13 未満になるように剰余演算が行われます。
c = a/b
は除算の例です。剰余同士の除算は被除数と除数の逆元の乗算となるため、 a (= 8)
と b (= 9)
の除算は 8
と 3
の乗算で計算できます。
冪演算
c = a**5
c = a**(-5_8)
a ** 5 = 8
a ** -5_8 = 5
c = a**5
、 c = a**(-5_8)
は冪乗の例です。指数の型は integer(8) もしくは integer(4) のみ可能です。指数が負の場合は逆元の (指数の絶対値) 乗を返します。計算量のオーダーは $O(\log |指数|)$ です。
integer(8)・integer(4) との互換性
c = a+11
c = 17-a
c = a*9_8
c = a/4_8
a + 11 = 6
17 - a = 9
a * 9_8 = 7
a / 4_8 = 2
modint に integer(8)・integer(4) を代入する際と同様に、四則演算の演算子は integer(8)・integer(4) に対する演算をオーバーロードしているので、上記のような書き方が可能です。
おわりに
modint はソースコードをスッキリさせるのに大いに役立ちます。特に Fortran は剰余演算が %
ではなく mod()
なので、剰余計算は煩雑になりがちなので重宝します。例えば
ans = mod(ans+mod(mod(a*b,md)*mod(c*d,md),md),md)
は modint を用いれば
ans = ans+a*b*c*d
と書けます。嬉しいですね。
ただ一つ文句を言うならば、演算子のオーバーロードをする際に、 Fortran の function・subroutine は integer の kind ごとに書かなければならない点が煩雑です。 Fortran2003 で導入された class(*) を用いようとすると (modint) + (modint)
の場合と (modint) + (integer)
の場合を区別できずにコンパイルエラーがおきます。 integer の kind をいっしょくたに扱えれば簡潔に書けるのですが…。
念のため、このプログラムがきちんと実装できているかを確認するために、以下の問題を modint を用いて解きました。
AtCoder Beginner Contest 110 : D - Factorization
その他の参考
Pythonでmodintを実装してみた
modint 構造体を使ってみませんか? (C++)
追記 2020/05/16
使用している Mac の OS をアップデートしたのに伴い、 gfortran をアップデートしたことにより class(*) を用いてもコンパイラが自動的に判断してくれるようになったのか、より簡素に書けるようになりました。おかげでソースコードの記述量を約半分から三分の一程度減らすことができました(旧ソースコード、新ソースコード)。 新しい方の modint を AtCoder で使う場合は、ジャッジシステムのアップデート後の Fortran(GNU Fortran 9.2.1) でないと使えないので注意してください。
新しい環境
macOS Mojave (10.14.6)
GNU Fortran (GCC) 8.2.0