LoginSignup
2
3

More than 3 years have passed since last update.

[競プロ]繰り返し2乗法【Java】【Python】

Last updated at Posted at 2021-04-21

久々の競プロ記事です。

今回は繰り返し2乗法についてです。Javaで書いてみます。また、繰り返し2乗法を使用した問題をJava, Pythonで解いてみます。

繰り返し2乗法とは

べき乗の計算量を減らすテクニックです。以下のように、指数を2のべき乗表記をして累乗計算をします。
N乗の計算がO(N)からO(logN)になります。

計算方法

3^10を求めます。
10 = 2^3 + 2^1と表せるため、
3^10 = 3^(2^3 + 2^1) = 3^(2^3) * 3^(2^1)と表せます。

きれいに書くと以下のとおりです。

3^{10} = 3^{2^3} * 3^{2^1}

「それはわかるけどなんでそう考えると計算量が減るの?」という声が聞こえて来そうです。
もうちょっと丁寧に説明しますね。
2進数のbit演算を使用しています。
こんな感じです↓

スクリーンショット 2021-04-21 9.05.13.png

計算の過程は以下のとおりです。
tmp・・・一時変数
ans・・・3^10の解が入る変数

# 今参照しているbit 説明
1 最下位bit bitが0のため、ansの更新はしない(ans=1)
tmpにtmpをかけて9とする(3*3=9)
2 下から2番目 bitが1のため、ans*=tmpとする(ans=1*9=9)
tmpにtmpをかけて81とする(9*9=81)
3 下から3番目 bitが0のため、ansの更新はしない(ans=9)
tmpにtmpをかけて6561とする(81*81=6561)
4 下から4番目 bitが1のため、ansの更新をする(ans=9*6561=59049)

こんな感じです。
tmpが重要ですね。最初は3だったのですが、上位bitを参照していくたびに2乗ずつして増えていきます。
2進数の桁が増えるのと同様、tmpも増えていく感じです。

実装

では、実装してみます。

    public static long myPow(long a, long n) {
        // a^nを計算
        long ans = 1l;
        long tmp = a;

        // わかりやすくfor文の中ですべて処理
        for (;;) {
            // すべての桁を見終わったら終了
            if (n < 1l) {
                break;
            }

            // 最下位bitが1かどうかの判定
            if (n % 2l == 1l) {
                ans *= tmp;
            }

            // tmpの更新
            tmp *= tmp;

            // nのbitを一つずらす
            n = n >> 1;
        }
        return ans;
    }

はい。こんな感じです。

myPow(3,10)を呼ぶとしっかり59049が返ってきます。

問題演習

2021/04/17(土)のAtCoder、第二回日本最強プログラマー学生選手権のD問題、Nowhere Pを解いてみます。

問題はリンクのとおりですが、答えは

(P-1)(P-2)^{n-1} mod 1000000007

を計算すれば良いです。
ただ、nの制約は最高で10^9で、単純にやるとTLEになる可能性が高いです。
なので繰り返し2乗法の登場です。

Javaの回答例

import java.util.Scanner;

public class Main {

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        long n = sc.nextLong();
        long p = sc.nextLong();
        long MOD = 1000000007l;

        long ans = (p - 1) * modPow(p - 2, n - 1, MOD) % MOD;

        System.out.println(ans);
    }

    public static long modPow(long a, long n, long mod) {
        // a^nを計算
        long ans = 1l;
        long tmp = a;

        // わかりやすくfor文の中ですべて処理
        for (;;) {
            // すべての桁を見終わったら終了
            if (n < 1l) {
                break;
            }

            // 最下位bitが1かどうかの判定
            if (n % 2l == 1l) {
                ans *= tmp;
                ans %= mod;
            }

            // tmpの更新
            tmp *= tmp;
            tmp %= mod;

            // nのbitを一つずらす
            n = n >> 1;
        }
        return ans;
    }

}

先程のmyPow関数をちょっと変えてmodPowにしました。計算の途中でmodとってるだけですね。
これで問題が解けました。

Python の回答例

ちなみに、Pythonのpow関数は繰り返し2乗法が実装されているので特に気にせず実装が出来ます。
また、 powの第三引数に値を入れるとmodPowの実装になります。

実装例(公式の解説とほぼ同じですが)

N, P = map(int, input().split())
MOD = 1000000007
ans = (P-1) * pow(P-2, N-1, MOD) % MOD
print(ans)

こんな感じです。

逆に競プロ以外で繰り返し2乗法や modPowが必要な場面を教えてほしいくらいですが・・・。

割と簡単なアルゴリズムだったので覚えておきたいですね。

繰り返し2乗法を紹介しました。今回の記事はここまでです。

2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3