久々の競プロ記事です。
今回は繰り返し2乗法についてです。Javaで書いてみます。また、繰り返し2乗法を使用した問題をJava, Pythonで解いてみます。
繰り返し2乗法とは
べき乗の計算量を減らすテクニックです。以下のように、指数を2のべき乗表記をして累乗計算をします。
N乗の計算がO(N)からO(logN)になります。
計算方法
3^10を求めます。
10 = 2^3 + 2^1と表せるため、
3^10 = 3^(2^3 + 2^1) = 3^(2^3) * 3^(2^1)と表せます。
きれいに書くと以下のとおりです。
3^{10} = 3^{2^3} * 3^{2^1}
「それはわかるけどなんでそう考えると計算量が減るの?」という声が聞こえて来そうです。
もうちょっと丁寧に説明しますね。
2進数のbit演算を使用しています。
こんな感じです↓
計算の過程は以下のとおりです。
tmp・・・一時変数
ans・・・3^10の解が入る変数
# | 今参照しているbit | 説明 |
---|---|---|
1 | 最下位bit | bitが0のため、ansの更新はしない(ans=1) tmpにtmpをかけて9とする(3*3=9) |
2 | 下から2番目 | bitが1のため、ans*=tmpとする(ans=19=9) tmpにtmpをかけて81とする(99=81) |
3 | 下から3番目 | bitが0のため、ansの更新はしない(ans=9) tmpにtmpをかけて6561とする(81*81=6561) |
4 | 下から4番目 | bitが1のため、ansの更新をする(ans=9*6561=59049) |
こんな感じです。
tmpが重要ですね。最初は3だったのですが、上位bitを参照していくたびに2乗ずつして増えていきます。
2進数の桁が増えるのと同様、tmpも増えていく感じです。
実装
では、実装してみます。
public static long myPow(long a, long n) {
// a^nを計算
long ans = 1l;
long tmp = a;
// わかりやすくfor文の中ですべて処理
for (;;) {
// すべての桁を見終わったら終了
if (n < 1l) {
break;
}
// 最下位bitが1かどうかの判定
if (n % 2l == 1l) {
ans *= tmp;
}
// tmpの更新
tmp *= tmp;
// nのbitを一つずらす
n = n >> 1;
}
return ans;
}
はい。こんな感じです。
myPow(3,10)
を呼ぶとしっかり59049が返ってきます。
問題演習
2021/04/17(土)のAtCoder、第二回日本最強プログラマー学生選手権のD問題、Nowhere Pを解いてみます。
問題はリンクのとおりですが、答えは
(P-1)(P-2)^{n-1} mod 1000000007
を計算すれば良いです。
ただ、nの制約は最高で10^9で、単純にやるとTLEになる可能性が高いです。
なので繰り返し2乗法の登場です。
Javaの回答例
import java.util.Scanner;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
long n = sc.nextLong();
long p = sc.nextLong();
long MOD = 1000000007l;
long ans = (p - 1) * modPow(p - 2, n - 1, MOD) % MOD;
System.out.println(ans);
}
public static long modPow(long a, long n, long mod) {
// a^nを計算
long ans = 1l;
long tmp = a;
// わかりやすくfor文の中ですべて処理
for (;;) {
// すべての桁を見終わったら終了
if (n < 1l) {
break;
}
// 最下位bitが1かどうかの判定
if (n % 2l == 1l) {
ans *= tmp;
ans %= mod;
}
// tmpの更新
tmp *= tmp;
tmp %= mod;
// nのbitを一つずらす
n = n >> 1;
}
return ans;
}
}
先程のmyPow関数をちょっと変えてmodPowにしました。計算の途中でmodとってるだけですね。
これで問題が解けました。
Python の回答例
ちなみに、Pythonのpow関数は繰り返し2乗法が実装されているので特に気にせず実装が出来ます。
また、 powの第三引数に値を入れるとmodPowの実装になります。
実装例(公式の解説とほぼ同じですが)
N, P = map(int, input().split())
MOD = 1000000007
ans = (P-1) * pow(P-2, N-1, MOD) % MOD
print(ans)
こんな感じです。
逆に競プロ以外で繰り返し2乗法や modPowが必要な場面を教えてほしいくらいですが・・・。
割と簡単なアルゴリズムだったので覚えておきたいですね。
繰り返し2乗法を紹介しました。今回の記事はここまでです。