こんにちは。
bitFlyerのAdvent Calendar 9日目の記事です。
初めまして、私はbitFlyerでAndroidエンジニアをしています。
私が担当しているアプリ開発においてはパフォーマンスが非常に重要で、高速化のためのアルゴリズムが用いられている箇所もあり、勉強中です。
ところで皆さんは10^9 + 7
という文字列を見たことがありますでしょうか。
私は、AtCoderをやっていて初めて見ました。
この文字列を見るときは大抵問題が難しいので「この表記がある問題が出てきたら捨てよう」とアレルギー反応が出ていましたが、問題自体は大した難易度ではないものが度々混じっているので非常にもったいないです。
いつかは越えねばならぬ壁なので色々調べましたが、Kotlinで書かれたサンプルコードが全然見つからなかったのでこの記事を書くに至りました。(以前Pythonでやっていた頃、Pythonなら豊富にあったのですが・・・)
この記事に書くこと
-
10^9 + 7
の余りを出力する問題を解く際の四則計算の方法 - 繰り返し二乗法の仕組み
対象読者
- 競プロのビギナー(Atcoderでいうと茶色〜緑色)
-
10^9 + 7
という文字列を見るとアレルギー反応が出てしまう人 - 細かい理屈はともかく、
10^9 + 7
を打破したい人
では早速、まずは四則計算から
足し算
これに関しては特筆すべき点はありません。
const val MOD = 1000000007L
a += b % MOD
a %= MOD
こまめにMODの余りに変換しておきましょう。
引き算
const val MOD = 1000000007L
a -= b % MOD
a += MOD
a %= MOD
負になってしまわないようにMODを足しています。
掛け算
この手の問題だと1<= a <= 10^9
1<= b <= 10^9
といった条件がついているかと思います。
10^9までならLong型同士で掛け算すればせいぜい10^18なのでオーバーフローしないはずです。
const val MOD = 1000000007L
a *= b % MOD
a %= MOD
特にこれに関しても特筆すべき点はなく、オーバーフローしないように計算後にちゃんと割りましょうということくらいです。
難しいのは割り算です。
割り算
const val MOD = 1000000007L
a /= b // NG
a %= MOD
これはNGです。
a(mod 1000000007L)
をb(mod 1000000007L)
で割るのはNGです。
ではどうすれば良いのかというと、まず結論から書くと
const val MOD = 1000000007L
fun main(args: Array<String>) {
a *= b.modPow(MOD - 2)
}
fun Long.modPow(n: Long, mod: Long = 1000000007L): Long {
assert(n >= 0)
if (n == 0L) return 1L
if (n == 1L) return this
if (n % 2L != 0L) return (this * modPow(n - 1, mod)) % mod
val t = modPow(n.shr(1), mod) // n.shr(1)は、2で割っているだけです(右に一つビットシフト)
return (t * t) % mod
}
何をしているのかというと、「aをbで割る」の部分を「aにb^(1000000007L - 2)を掛ける」というコードに書き換えています。(※この変形ができる理由を私は理解していないので「フェルマーの小定理」や「逆元」でググってみてください。)
aをbで割る
の計算部分をaにb^(1000000007L - 2)を掛ける
という計算に置き換えられることはここでは自明として扱って先に進みます。
b^(1000000007L - 2)
を計算するにあたっていくつか注意点があります。
- オーバーフローしないように掛け算の都度
1000000007L
で割ること - 単純に
1000000007L - 2
回の掛け算をしたらTLE待ったなしなので、効率的に計算すること(繰り返し二乗法を用いる)
以上をクリアするLongの拡張メソッドが
fun Long.modPow(n: Long, mod: Long = 1000000007L): Long {
assert(n >= 0)
if (n == 0L) return 1L
if (n == 1L) return this
if (n % 2L != 0L) return (this * modPow(n - 1, mod)) % mod
val t = modPow(n / 2, mod)
return (t * t) % mod
}
こちらになります。
この拡張メソッドのアルゴリズムはシンプルで
- nが2で割り切れるときは、x^n * x^nの形にしてしまおう
- nが2で割り切れないときは、割り切れる形にしよう
ポイントはこの二点だけです。
例えばx^66
をこのアルゴリズムに当てはめてみると
x^66
= x^33 * x^33
x^33
= x * x^32
x^32
= x^16 * x^16
x^16
= x^8 * x^8
・・・
n乗するときの計算量のオーダーは、O(log N)
かと思います。
これを用いると高速で逆元を求められることがわかりました。
終わりに
Kotlinで競プロをやっている人はあまりいないかもしれませんが、Intellijという強力なIDEを使えますし、急いで書いても予期せぬ型が入ってこないので手戻りが少なくて非常に書きやすいです。
おすすめです!
最後に、弊社bitFlyerではAndroidエンジニアを募集しています。ご興味のある方はぜひご応募ください!