LoginSignup
3
0

More than 3 years have passed since last update.

Fortran で modint (自動で剰余演算される整数) を実装してみた

Last updated at Posted at 2019-08-22

はじめに

競技プログラミング界隈で用いられている modint (自動で剰余演算される整数) を Fortran で実装しました。実装は AtCoder の解説放送 (ABC 第 129, 130, 133 回) (リンク先は YouTube の動画なので音量に注意してください) を参考にしました。

Fortran で実装した理由としては、 Fortran での実装例が見当たらなかったのと、最近触り始めた Go 言語で modint を実装しようとした際に、 Go 言語が演算子のオーバーロードをサポートしていない事実に気がついたためです。

環境

macOS Sierra (10.12.6)
GNU Fortran (GCC) 6.3.0

プログラム

実際に実装したものが以下のコードです。通常の integer との互換性を持たせるために少々長くなってしまいましたが、演算子をオーバーロードしているおかげで integer と同じ感覚で演算可能です。

modint.f08
module mod_modint
  implicit none
  integer(8) :: modulus = 1000000007_8
  ! integer(8) :: modulus = 998244353_8
  ! integer(8) :: modulus = 1000000009_8

  type modint
    private
    integer(8) :: num
  contains
    procedure :: get => getnum
  end type

  interface change_modulus
    module procedure :: change_modulus64, change_modulus32
  end interface change_modulus

  interface modint
    module procedure :: newm, newi
  end interface modint

  interface assignment(=)
    module procedure :: setm, seti64, seti32
  end interface assignment(=)

  interface operator(+)
    module procedure :: posm, addmm, addmi64, addmi32, addi64m, addi32m
  end interface operator(+)

  interface operator(-)
    module procedure :: negm, submm, submi64, submi32, subi64m, subi32m
  end interface operator(-)

  interface operator(*)
    module procedure :: mulmm, mulmi64, mulmi32, muli64m, muli32m
  end interface operator(*)

  interface operator(/)
    module procedure :: divmm, divmi64, divmi32, divi64m, divi32m
  end interface operator(/)

  interface operator(**)
    module procedure :: powmi64, powmi32
  end interface operator(**)

  interface inv
    module procedure :: invm
  end interface inv

contains

  subroutine change_modulus64(newmod)
    integer(8), intent(in) :: newmod
    if (newmod == 0_8) then
      write(*,'(a)') "Error: Invalid value (newmod == 0). (modint, change_modulus64)"
      stop
    end if
    modulus = newmod
  end

  subroutine change_modulus32(newmod)
    integer, intent(in) :: newmod
    if (newmod == 0) then
      write(*,'(a)') "Error: Invalid value (newmod == 0). (modint, change_modulus32)"
      stop
    end if
    modulus = int(newmod,8)
  end

  integer(8) function getnum(this)
    class(modint), intent(in) :: this
    getnum = this%num
  end

  pure elemental type(modint) function newm()
    newm%num = 0_8
  end

  pure elemental subroutine setm(x,y)
    type(modint), intent(inout) :: x
    type(modint), intent(in) :: y
    x%num = y%num
  end

  pure elemental function posm(x) result(n)
    type(modint), intent(in) :: x
    type(modint) :: n
    n%num = x%num
  end

  pure elemental function negm(x) result(n)
    type(modint), intent(in) :: x
    type(modint) :: n
    n%num = modulus-x%num
  end

  pure elemental function addmm(x,y) result(n)
    type(modint), intent(in) :: x, y
    type(modint) :: n
    n%num = x%num+y%num
    if (n%num >= modulus) n%num = n%num-modulus
  end

  pure elemental function submm(x,y) result(n)
    type(modint), intent(in) :: x, y
    type(modint) :: n
    n%num = x%num-y%num
    if (n%num < 0_8) n%num = n%num+modulus
  end

  pure elemental function mulmm(x,y) result(n)
    type(modint), intent(in) :: x, y
    type(modint) :: n
    n%num = mod(x%num*y%num,modulus)
  end

  impure elemental function invm(x) result(n)
    type(modint), intent(in) :: x
    type(modint) :: n
    integer(8) :: a, b, c, q, r, v
    a = x%num
    if (a == 0_8) then
      write(*,'(a)') "Error: Division by zero (x == 0). (modint, invm)"
      stop
    end if
    b = modulus
    c = 0_8
    v = 1_8
    do while (b /= 0_8)
      q = a/b
      r = mod(a,b)
      a = b
      b = r
      r = c
      c = v-c*q
      v = r
    end do
    n%num = mod(v,modulus)
    if (n%num < 0_8) n%num = n%num+modulus
  end

  impure elemental function divmm(x,y) result(n)
    type(modint), intent(in) :: x, y
    type(modint) :: n
    n = mulmm(x,invm(y))
  end

  !##########################################################################
  !##################### overload with (normal) integer #####################
  !##########################################################################

  impure elemental type(modint) function newi(i)
    class(*), intent(in) :: i
    select type(i)
    type is (integer(8))
      newi%num = i
    type is (integer)
      newi%num = int(i,8)
    type is (integer(2))
      newi%num = int(i,8)
    type is (integer(1))
      newi%num = int(i,8)
    class default
      write(*,'(a)') "Error: Invalid value (i is not integer). (modint, newi)"
      stop
    end select
    newi%num = mod(newi%num,modulus)
    if (newi%num < 0_8) newi%num = newi%num+modulus
  end

  impure elemental subroutine seti64(x,i)
    type(modint), intent(inout) :: x
    integer(8), intent(in) :: i
    call setm(x,newi(i))
  end

  impure elemental subroutine seti32(x,i)
    type(modint), intent(inout) :: x
    integer, intent(in) :: i
    call setm(x,newi(i))
  end

  impure elemental function addmi64(x,i) result(n)
    type(modint), intent(in) :: x
    integer(8), intent(in) :: i
    type(modint) :: n
    n = addmm(x,newi(i))
  end

  impure elemental function addmi32(x,i) result(n)
    type(modint), intent(in) :: x
    integer, intent(in) :: i
    type(modint) :: n
    n = addmm(x,newi(i))
  end

  impure elemental function addi64m(i,y) result(n)
    integer(8), intent(in) :: i
    type(modint), intent(in) :: y
    type(modint) :: n
    n = addmm(newi(i),y)
  end

  impure elemental function addi32m(i,y) result(n)
    integer, intent(in) :: i
    type(modint), intent(in) :: y
    type(modint) :: n
    n = addmm(newi(i),y)
  end

  impure elemental function submi64(x,i) result(n)
    type(modint), intent(in) :: x
    integer(8), intent(in) :: i
    type(modint) :: n
    n = submm(x,newi(i))
  end

  impure elemental function submi32(x,i) result(n)
    type(modint), intent(in) :: x
    integer, intent(in) :: i
    type(modint) :: n
    n = submm(x,newi(i))
  end

  impure elemental function subi64m(i,y) result(n)
    integer(8), intent(in) :: i
    type(modint), intent(in) :: y
    type(modint) :: n
    n = submm(newi(i),y)
  end

  impure elemental function subi32m(i,y) result(n)
    integer, intent(in) :: i
    type(modint), intent(in) :: y
    type(modint) :: n
    n = submm(newi(i),y)
  end

  impure elemental function mulmi64(x,i) result(n)
    type(modint), intent(in) :: x
    integer(8), intent(in) :: i
    type(modint) :: n
    n = mulmm(x,newi(i))
  end

  impure elemental function mulmi32(x,i) result(n)
    type(modint), intent(in) :: x
    integer, intent(in) :: i
    type(modint) :: n
    n = mulmm(x,newi(i))
  end

  impure elemental function muli64m(i,y) result(n)
    integer(8), intent(in) :: i
    type(modint), intent(in) :: y
    type(modint) :: n
    n = mulmm(newi(i),y)
  end

  impure elemental function muli32m(i,y) result(n)
    integer, intent(in) :: i
    type(modint), intent(in) :: y
    type(modint) :: n
    n = mulmm(newi(i),y)
  end

  impure elemental function divmi64(x,i) result(n)
    type(modint), intent(in) :: x
    integer(8), intent(in) :: i
    type(modint) :: n
    n = divmm(x,newi(i))
  end

  impure elemental function divmi32(x,i) result(n)
    type(modint), intent(in) :: x
    integer, intent(in) :: i
    type(modint) :: n
    n = divmm(x,newi(i))
  end

  impure elemental function divi64m(i,y) result(n)
    integer(8), intent(in) :: i
    type(modint), intent(in) :: y
    type(modint) :: n
    n = divmm(newi(i),y)
  end

  impure elemental function divi32m(i,y) result(n)
    integer, intent(in) :: i
    type(modint), intent(in) :: y
    type(modint) :: n
    n = divmm(newi(i),y)
  end

  impure elemental function powmi64(x,i) result(n)
    type(modint), intent(in) :: x
    integer(8), intent(in) :: i
    type(modint) :: n, p
    integer(8) :: m
    n = newi(1_8)
    p = x
    m = i
    if (i < 0_8) then
      p = invm(x)
      m = abs(i)
    end if
    do while (m > 0_8)
      if (btest(m,0)) n = mulmm(p,n)
      p = mulmm(p,p)
      m = rshift(m,1)
    end do
  end

  impure elemental function powmi32(x,i) result(n)
    type(modint), intent(in) :: x
    integer, intent(in) :: i
    type(modint) :: n, p
    integer :: m
    n = newi(1_8)
    p = x
    m = i
    if (i < 0) then
      p = invm(x)
      m = abs(i)
    end if
    do while (m > 0)
      if (btest(m,0)) n = mulmm(p,n)
      p = mulmm(p,p)
      m = rshift(m,1)
    end do
  end

end module mod_modint

動作確認

動作確認のために以下のようなコードを用意しました。

test_modint.f08
program test_modint
  use mod_modint
  implicit none
  type(modint) :: a, b, c
  call change_modulus(13)
  a = modint(8)
  write(*,'(a,i0)') "a         = ", a%get()
  b = 22_8
  write(*,'(a,i0)') "b         = ", b%get()
  c = -5
  write(*,'(a,i0)') "c         = ", c%get()
  c = -a
  write(*,'(a,i0)') "-a        = ", c%get()
  c = a+b
  write(*,'(a,i0)') "a + b     = ", c%get()
  c = a-b
  write(*,'(a,i0)') "a - b     = ", c%get()
  c = a*b
  write(*,'(a,i0)') "a * b     = ", c%get()
  c = a/b
  write(*,'(a,i0)') "a / b     = ", c%get()
  c = a**5
  write(*,'(a,i0)') "a ** 5    = ", c%get()
  c = a**(-5_8)
  write(*,'(a,i0)') "a ** -5_8 = ", c%get()
  c = a+11
  write(*,'(a,i0)') "a + 11    = ", c%get()
  c = 17-a
  write(*,'(a,i0)') "17 - a    = ", c%get()
  c = a*9_8
  write(*,'(a,i0)') "a * 9_8   = ", c%get()
  c = a/4_8
  write(*,'(a,i0)') "a / 4_8   = ", c%get()
end program test_modint

実行結果は以下のとおりです。

a         = 8
b         = 9
c         = 8
-a        = 5
a + b     = 4
a - b     = 12
a * b     = 7
a / b     = 11
a ** 5    = 8
a ** -5_8 = 5
a + 11    = 6
17 - a    = 9
a * 9_8   = 7
a / 4_8   = 2

上から順に見ていきます。

法の変更

call change_modulus(13)

modint の法 modulus は、デフォルトでは 1000000007 としていますが、途中で変更したい場合は change_modulus() で変更できます。今回の場合、 1000000007 は大きすぎるので、わかりやすいように 13 に変更しています。この後除算を行うので、 13 を法とした場合の逆元を記しておきます。

元の数 1 2 3 4 5 6 7 8 9 10 11 12
逆元 1 7 9 10 8 11 2 5 3 4 6 12

コンストラクタ・代入

a = modint(8)
b = 22_8
c = -5
出力結果
a         = 8
b         = 9
c         = 8

a = modint(8) では、変数 a にコンストラクタ modint() で生成された modint を代入しています。
b = 22_8 では、変数 b に integer(8) である 22_8 を代入しています。本来ならばこのような書き方は、 b22_8 の型が異なるのでできません (以下のコンパイル時のメッセージを参照) が、今回のプログラムでは integer(8)・integer(4) に対する代入演算子のオーバーロードを行なっているのでコンパイルが通ります。

コンパイル時エラーメッセージ
test_modint.f08:8:6:

   b = 22_8

      1
Error: Can't convert INTEGER(8) to TYPE(modint) at (1)

また、 b に代入している 22_8 は法である 13 以上の数なので、剰余演算が行われて 9_8 となっています。
c = -5 では、変数 c に integer(4) である -5 を代入しています。これも b の場合と同様です。また、 c に代入している -5 は負の数なので、 0 以上 13 未満になるように剰余計算され、 integer(4) の場合は integer(8) に変換されるので 8_8 となります。

四則演算

c = -a
c = a+b
c = a-b
c = a*b
c = a/b
出力結果
-a        = 5
a + b     = 4
a - b     = 12
a * b     = 7
a / b     = 11

c = -a は、単項演算子 - の例です。この演算では 0 以下の数になるため、自動的に 0 以上 13 未満になるように modulus が足されます。
c = a+bc = a-bc = a*b はそれぞれ加算、減算、乗算の例です。これらも 0 以上 13 未満になるように剰余演算が行われます。
c = a/b は除算の例です。剰余同士の除算は被除数と除数の逆元の乗算となるため、 a (= 8)b (= 9) の除算は 83 の乗算で計算できます。

冪演算

c = a**5
c = a**(-5_8)
出力結果
a ** 5    = 8
a ** -5_8 = 5

c = a**5c = a**(-5_8) は冪乗の例です。指数の型は integer(8) もしくは integer(4) のみ可能です。指数が負の場合は逆元の (指数の絶対値) 乗を返します。計算量のオーダーは $O(\log |指数|)$ です。

integer(8)・integer(4) との互換性

c = a+11
c = 17-a
c = a*9_8
c = a/4_8
出力結果
a + 11    = 6
17 - a    = 9
a * 9_8   = 7
a / 4_8   = 2

modint に integer(8)・integer(4) を代入する際と同様に、四則演算の演算子は integer(8)・integer(4) に対する演算をオーバーロードしているので、上記のような書き方が可能です。

おわりに

modint はソースコードをスッキリさせるのに大いに役立ちます。特に Fortran は剰余演算が % ではなく mod() なので、剰余計算は煩雑になりがちなので重宝します。例えば

ans = mod(ans+mod(mod(a*b,md)*mod(c*d,md),md),md)

は modint を用いれば

ans = ans+a*b*c*d

と書けます。嬉しいですね。

ただ一つ文句を言うならば、演算子のオーバーロードをする際に、 Fortran の function・subroutine は integer の kind ごとに書かなければならない点が煩雑です。 Fortran2003 で導入された class(*) を用いようとすると (modint) + (modint) の場合と (modint) + (integer) の場合を区別できずにコンパイルエラーがおきます。 integer の kind をいっしょくたに扱えれば簡潔に書けるのですが…。

念のため、このプログラムがきちんと実装できているかを確認するために、以下の問題を modint を用いて解きました。

AtCoder Beginner Contest 110 : D - Factorization

その他の参考

Pythonでmodintを実装してみた
modint 構造体を使ってみませんか? (C++)

追記 2020/05/16

使用している Mac の OS をアップデートしたのに伴い、 gfortran をアップデートしたことにより class(*) を用いてもコンパイラが自動的に判断してくれるようになったのか、より簡素に書けるようになりました。おかげでソースコードの記述量を約半分から三分の一程度減らすことができました(旧ソースコード新ソースコード)。 新しい方の modint を AtCoder で使う場合は、ジャッジシステムのアップデート後の Fortran(GNU Fortran 9.2.1) でないと使えないので注意してください。

新しい環境

macOS Mojave (10.14.6)
GNU Fortran (GCC) 8.2.0

3
0
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
0