LoginSignup
3
1

More than 3 years have passed since last update.

Tonelli-Shanks アルゴリズムの理解と実装(2)

Last updated at Posted at 2020-10-21

はじめに

前回の記事はこちらです。

実装

まず、今回仮定する条件をもう一度しておきます。

  • Input:
    • $p$ : 奇素数
    • $n$ : 整数, ただし$n$は$p$の倍数ではない
  • Output:
    • $x$ : $x^2 \equiv n\ {\rm mod}\ p$, かつ $0 < x < n$を満たす1

$n$が$p$の倍数なら、$n \equiv 0\ {\rm mod}\ p$で剰余とは言いづらいので今回は除外します(例外としてコード中に組み込みます)。

Legendre記号

$$\left(\begin{array}
nn\\
p
\end{array}\right) := n^{\frac{p-1}{2}} \equiv \begin{cases}
1\ {\rm mod}\ p\Leftrightarrow n は平方剰余\\
-1\ {\rm mod}\ p \Leftrightarrow n は平方非剰余
\end{cases}$$

でした。ここで、$n$が$p$の倍数なら当然
$$\left(\begin{array}
nn\\
p
\end{array}\right) = 0$$
となるので、ここで例外処理をしておきます。

なお、pow(a, b, c)は$a^b \ {\rm mod}\ c$を表します2

def legendre_symbol(n, p):
    ls = pow(n, (p - 1) // 2, p)
    if ls == 1:
        return 1
    # pow関数は0 ~ p-1の範囲で値を返します
    elif ls == p - 1:
        return -1
    else:
        # ls == 0、つまりnがpの倍数の場合
        raise Exception('n:{} = 0 mod p:{}'.format(n, p))

pが4で割って3余る素数の場合

$$x = \pm n^{\frac{p+1}{4}}$$

が答えでした。一応答えが正しいことを確認するcheck_sqrt関数も定義しておきます。

# 基本的にここでassertionエラーは出ないはず
def check_sqrt(x, n, p):
    assert(pow(x, 2, p) == n % p)

def modular_sqrt(n, p):
    if p % 4 == 3:
        x = pow(n, (p + 1) // 4, p)
        check_sqrt(x, n, p)
        return [x, p - x]
    else:
        # これから説明
        pass

pが4で割って1余る素数の場合

ここからTonelli-Shanksのアルゴリズムの実装です。

Step 1.

$$p-1 = Q \cdot 2^S$$

という形に変形します($Q$は奇数で、$S$は正の整数です)。

Pythonでは多くの場合、大文字は定数を意味するので小文字のq, sを使います。

def modular_sqrt(n, p):
    ...
    else:
        # Step 1.
        q, s = p - 1, 0
        while q % 2 == 0:
            q //= 2
            s += 1
    ...

Step 2.

平方非剰余である$z$をランダムに選美ます。

前回も述べたように、半数は平方非剰余なので$2$から総当たりしています。

なぜ$1$から始めないのかというと、$x^2 \equiv 1\ {\rm mod} p$なる$x$は自明に存在($x=1$)して、任意の$p$に対して$1$は平方剰余だからです。

def modular_sqrt(n, p):
    ...
    else:
        # Step 1.
        q, s = p - 1, 0
        while q % 2 == 0:
            q //= 2
            s += 1

        # Step 2.
        z = 2
        while legendre_symbol(z, p) != -1:
            z += 1
    ...

Step 3.

$$\begin{cases}
M_0 = S\\
c_0 = z^Q\\
t_0 = n^Q\\
R_0 = n^{\frac{Q+1}{2}}
\end{cases}$$

これはそのままです。先ほどと同様、すべて小文字で定義します。

def modular_sqrt(n, p):
    ...
    else:
        ...
        # Step 2.
        z = 2
        while legendre_symbol(z, p) != -1:
            z += 1

        # Step 3.
        m, c, t, r = s, pow(z, q, p), pow(n, q, p), pow(n, (q + 1) // 2, p)
    ...

Step 4.

  1. もし$t_i \equiv 1$なら、$\pm R_i$が$n$の平方根であり、ループ文を抜けて終了。

  2. そうでない場合は、以下のように値を更新する。

$$\begin{cases}
M_{i+1} = \left(\left(t_i\right)^{2^{j}}\equiv 1 を満たす最小のj, ただし0 < j < M_i\right)\\
\\
c_{i+1} = \left(c_i\right)^{2^{M_i - M_{i+1}}}\\
\\
t_{i+1} = t_i \cdot \left(c_i\right)^{2^{M_i - M_{i+1}}}\\
\\
R_{i+1} = R_i \cdot \left(c_i\right)^{2^{M_i - M_{i+1}-1}}
\end{cases}$$

さて、これを実装する上で、2つ補足説明をしておきます。

1つ目。

$M_{i+1}$を求める上で、$j = 1, 2, \cdots$と順に代入していく分けですが、「$t_i$を$2$回掛けて$(t_i)^2$を計算し、1になるか調べる。ならなければ$t_i$を$4$回掛けて$(t_i)^4$を計算し、1になるか調べる。…」とする以下のコードは少し無駄が多いと思いませんか?(m_updateが$M_{i+1}$に相当します)

for j in range(1, m):
    if pow(t, pow(2, j), p) == 1:
        m_update = j
        break

せっかく$(t_i)^2$を計算しているなら、それを2乗することで$(t_i)^4$が求められるので再利用しましょう3

pow_t = pow(t, 2, p)
for j in range(1, m):
    if pow_t == 1:
        m_update = j
        break
    pow_t = pow(pow_t, 2, p)

2つ目。

$$b_i = \left(c_i\right)^{2^{M_i - M_{i+1}-1}}$$

と定義すると、値の更新は以下のようにすっきりと書けます。この記号はWikipediaにも記載されています。

$$\begin{cases}
M_{i+1} = \left(\left(t_i\right)^{2^{j}}\equiv 1 を満たす最小のj, ただし0 < j < M_i\right)\\
\\
c_{i+1} = b_i^2\\
\\
t_{i+1} = t_i \cdot b_i^2\\
\\
R_{i+1} = R_i \cdot b_i
\end{cases}$$

前回はこれ以上変数を導入するとややこしいかなと思ったので割愛していました。

以上2点を踏まえて、以下のようにコードが書けます。

def modular_sqrt(n, p):
    ...
    else:
        ...
        # Step 3.
        m, c, t, r = s, pow(z, q, p), pow(n, q, p), pow(n, (q + 1) // 2, p)

        # Step 4.
        while t != 1:
            pow_t = pow(t, 2, p)
            for j in range(1, m):
                if pow_t == 1:
                    m_update = j
                    break
                pow_t = pow(pow_t, 2, p)
            b = pow(c, int(pow(2, m - m_update - 1)), p)
            m, c, t, r = m_update, pow(b, 2, p), t * pow(b, 2, p) % p, r * b % p

        # 答えの確認
        check_sqrt(r, n, p)
        return [r, p - r]

実装上の注意としては、$c_{i+1}, t_{i+1}, R_{i+1}$に更新する際に$M_i$と$M_{i+1}$の両方を用いるので、一旦m_update = jとおいてmをすぐに更新しないことです。

その他

$p$が素数であるかどうかの確認は実は多項式時間で可能です4

from gmpy2 import is_prime

is_prime(p)

あるいは

from Crypto.Util.number import isPrime

isPrime(p)

で高速な素数判定が可能です。

どちらも標準では入っていないモジュールだと思うので、pip3でインストールする必要があります。

全体のソースコード

#!/usr/bin/env python3

from Crypto.Util.number import isPrime
# from gmpy2 import is_prime

def legendre_symbol(n, p):
    ls = pow(n, (p - 1) // 2, p)
    if ls == 1:
        return 1
    elif ls == p - 1:
        return -1
    else:
        # in case ls == 0
        raise Exception('n:{} = 0 mod p:{}'.format(n, p))

def check_sqrt(x, n, p):
    assert(pow(x, 2, p) == n % p)

def modular_sqrt(n:int, p:int) -> list:
    if type(n) != int or type(p) != int:
        raise TypeError('n and p must be integers')

    if p < 3:
        raise Exception('p must be equal to or more than 3')

    if not isPrime(p):
        raise Exception('p must be a prime number. {} is a composite number'.format(p))

    if legendre_symbol(n, p) == -1:
        raise Exception('n={} is Quadratic Nonresidue modulo p={}'.format(n, p))

    if p % 4 == 3:
        x = pow(n, (p + 1) // 4, p)
        check_sqrt(x, n, p)
        return [x, p - x]

    # Tonelli-Shanks
    q, s = p - 1, 0
    while q % 2 == 0:
        q //= 2
        s += 1
    z = 2
    while legendre_symbol(z, p) != -1:
        z += 1
    m, c, t, r = s, pow(z, q, p), pow(n, q, p), pow(n, (q + 1) // 2, p)
    while t != 1:
        pow_t = pow(t, 2, p)
        for j in range(1, m):
            if pow_t == 1:
                m_update = j
                break
            pow_t = pow(pow_t, 2, p)
        b = pow(c, int(pow(2, m - m_update - 1)), p)
        m, c, t, r = m_update, pow(b, 2, p), t * pow(b, 2, p) % p, r * b % p
    check_sqrt(r, n, p)
    return [r, p - r]

print(modular_sqrt(5, 41))
# => [28, 13]

  1. $0 < x < p$で解が見つかれば、その$x$に対して$p$を順々に足して(あるいは引いて)いったものも当然解になるので、今回は$0 < x < p$の範囲に限定しました。 

  2. Pythonであればa ** b % cという書き方もできますが、pow関数を使った方がより高速に動作します。 

  3. 例えば、$0 < j < 10$なら、前者はpow関数の呼び出しは最大18回ですが、後者は9回で済みます。すごくアバウトな計算量評価ですが、こういった観点から後者の方を今回採用しました。ただ、実際はどちらを使っても元々のアルゴリズムが高速なのでほぼ差は出ません。 

  4. 「$p$が素数でない(=合成数である)」と「$p$の素因数分解の結果が分かる」は別です。今回は前者について触れています。後者が多項式時間でできてしまうと、いわゆるRSA暗号などの安全性に対する理論が瓦解してしまいます。(かといって、多項式時間ではできないだろうと言われているだけで、まだ未解決の問題のままです。) 

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1