はじめに
拡張ユークリッドの互除法による逆元の計算はしばしば必要になりますが,そのコードは再帰を使ったり変数が増えたり複雑になりがちだと思います.
ここでは拡張ユークリッドの互除法の変形を用いたシンプルな逆元の計算アルゴリズム紹介したいと思います.
この記事は拡張ユークリッドの互除法が何かを知っていることを前提とします.
概要
通常の拡張ユークリッドの互除法では入力 $a, N$ が与えられたとき $as + Nt=u$ と $ax+Ny=z$ を互除法にかけ変数 $(s,t,u),\ (x, y, z)$ の更新を交互に繰り返します.
今回紹介する変形のポイントは,一つ目の式を $s = 0,\ t = 1,\ u = N$ と固定してしまい二つ目の式のみを更新するようにすることです.
逆元の計算には $t,\ y$ が必要ないとしてもよいことに注意すると実装は以下のようになります.
def inv(a, N):
x = 1
z = a
while z > 1:
q = N // z
z = N - q * z
x = -x * q % N
return x
# 83 * 34 % 91 = 1 を出力
print(f"{inv(34, 91)} * 34 % 91 = {inv(34, 91) * 34 % 91}")
正当性の確認
$z \leftarrow N - \lfloor \frac{N}{z} \rfloor\times z = N \bmod z$ と更新していることからわかるように $z$ は単調減少で, $N$が素数で,$a$ と $N$ が互いに素であればアルゴリズムは必ず停止し,またこのアルゴリズムが終了したときは正しい出力を返すことは簡単に確認できます.
計算量
計算量は (私が考えた限りでは) わからないです.
OEIS を見るに $O(\log N$) 時間で推移してくれそうな気がします.
ランダムケースで走らせたところ拡張ユークリッドの互除法より速い様です.
Wandbox
Fermat : result=4990025251834694 in 3.82656 seconds
Euclid : result=4990025251834694 in 3.32045 seconds
Simple Euclid : result=4990025251834694 in 2.52363 seconds
実験コード
#include <chrono>
#include <iostream>
#include <random>
#include <utility>
#include <vector>
long long modinv_1(long long a, long long N) {
long long r = N - 2, b = 1;
while (r > 0) {
if (r % 2 == 1) {
b = (b * a) % N;
}
a = (a * a) % N;
r >>= 1;
}
return b;
}
long long modinv_2(long long a, long long N) {
long long s = 1, g = a, t = 0, w = N;
while (w > 0) {
long long q = g / w;
s -= q * t;
g -= q * w;
std::swap(s, t);
std::swap(g, w);
}
if (s < 0) s += N;
return s;
}
long long modinv_3(long long a, long long N) {
long long b = 1;
while (a > 1) {
long long q = N / a;
a = N - a * q;
b = -b * q % N;
}
if (b < 0) b += N;
return b;
}
struct Timer {
std::chrono::time_point<std::chrono::high_resolution_clock> start;
Timer() : start(std::chrono::high_resolution_clock::now()) {}
void reset() { start = std::chrono::high_resolution_clock::now(); }
double seconds() {
return std::chrono::duration_cast<std::chrono::duration<double>>(
std::chrono::high_resolution_clock::now() - start)
.count();
}
};
int main() {
const int N = 10000000;
const long long MOD = 998244353;
std::vector<long long> a(N);
std::mt19937 mt(0);
std::uniform_int_distribution<long long> dist(1, MOD - 1);
for (int i = 0; i < N; ++i) {
a[i] = dist(mt);
}
{
Timer timer;
long long s = 0;
for (long long i : a) {
s += modinv_1(i, MOD);
}
double t = timer.seconds();
std::cout << "Fermat\t\t: result=" << s << " in " << t << " seconds"
<< std::endl;
}
{
Timer timer;
long long s = 0;
for (long long i : a) {
s += modinv_2(i, MOD);
}
double t = timer.seconds();
std::cout << "Euclid\t\t: result=" << s << " in " << t << " seconds"
<< std::endl;
}
{
Timer timer;
long long s = 0;
for (long long i : a) {
s += modinv_3(i, MOD);
}
double t = timer.seconds();
std::cout << "Simple Euclid\t: result=" << s << " in " << t << " seconds"
<< std::endl;
}
}
追記その1 (計算量の改善)
次の $z$ を決める際にここまでは $z \leftarrow N \bmod z$ としていましたが,$r \equiv N (\bmod z)$ を満たす $r$ のうち $|r|$ が最小のものをとって $z \leftarrow r$ とすることにします.
こうすることで必ず $z$ の絶対値は $\frac{1}{2}$ 以下になるので計算量は $O(\log N)$ が保証されます.
def inv(a, N):
x = 1
z = a
while abs(z) > 1:
q = (N + (z >> 1)) // z
z = N - q * z
x = -x * q % N
if z < 0:
x = N - x
return x
long long inv(long long a, long long N) {
long long x = 1;
long long z = a;
while (std::abs(z) > 1) {
long long q = (N + (std::abs(z) >> 1)) / z;
z = N - q * z;
x = -x * q % N;
}
if (z < 0) x = -x;
if (x < 0) x += N;
return x;
}
追記その2 (テーブル化)
$x$ に掛けられる値は $z$ のみに依存するため,よく使う $z$ についてこれから $x$ に掛けられる値をあらかじめ求めておくと計算時間が短縮できます.
実装例は下の実験コードを参照してください.
Fermat : result=4990025251834694 in 3.83697 seconds
Euclid : result=4990025251834694 in 3.30284 seconds
Simple Euclid : result=4990025251834694 in 2.51167 seconds
Simple Euclid 2 : result=4990025251834694 in 1.86981 seconds
Table Euclid : result=4990025251834694 in 1.07563 seconds
実験コード
#include <chrono>
#include <cmath>
#include <iostream>
#include <random>
#include <utility>
#include <vector>
long long modinv_1(long long a, long long N) {
long long r = N - 2, b = 1;
while (r > 0) {
if (r % 2 == 1) {
b = (b * a) % N;
}
a = (a * a) % N;
r >>= 1;
}
return b;
}
long long modinv_2(long long a, long long N) {
long long s = 1, g = a, t = 0, w = N;
while (w > 0) {
long long q = g / w;
s -= q * t;
g -= q * w;
std::swap(s, t);
std::swap(g, w);
}
if (s < 0) s += N;
return s;
}
long long modinv_3(long long a, long long N) {
long long b = 1;
while (a > 1) {
long long q = N / a;
a = N - a * q;
b = -b * q % N;
}
if (b < 0) b += N;
return b;
}
long long modinv_4(long long a, long long N) {
long long b = 1;
while (std::abs(a) > 1) {
long long q = (N + (std::abs(a) >> 1)) / a;
a = N - a * q;
b = -b * q % N;
}
if (a < 0) b = -b;
if (b < 0) b += N;
return b;
}
std::vector<long long> gen_table(long long a, long long N) {
std::vector<long long> table(a);
table[0] = 1;
table[1] = 1;
for (int i = 2; i < a; i++) {
table[i] = (N - table[N % i]) * (N / i) % N;
}
return table;
}
long long modinv_5(long long a, long long N,
const std::vector<long long>& table) {
long long b = 1;
while (std::abs(a) >= (int)table.size()) {
long long q = (N + (std::abs(a) >> 1)) / a;
a = N - a * q;
b = -b * q % N;
}
b = b * table[std::abs(a)] % N;
if (a < 0) b = -b;
if (b < 0) b += N;
return b;
}
struct Timer {
std::chrono::time_point<std::chrono::high_resolution_clock> start;
Timer() : start(std::chrono::high_resolution_clock::now()) {}
void reset() { start = std::chrono::high_resolution_clock::now(); }
double seconds() {
return std::chrono::duration_cast<std::chrono::duration<double>>(
std::chrono::high_resolution_clock::now() - start)
.count();
}
};
int main() {
const int N = 10000000;
const long long MOD = 998244353;
std::vector<long long> a(N);
std::mt19937 mt(0);
std::uniform_int_distribution<long long> dist(1, MOD - 1);
for (int i = 0; i < N; ++i) {
a[i] = dist(mt);
}
{
Timer timer;
long long s = 0;
for (long long i : a) {
s += modinv_1(i, MOD);
}
double t = timer.seconds();
std::cout << "Fermat\t\t: result=" << s << " in " << t << " seconds"
<< std::endl;
}
{
Timer timer;
long long s = 0;
for (long long i : a) {
s += modinv_2(i, MOD);
}
double t = timer.seconds();
std::cout << "Euclid\t\t: result=" << s << " in " << t << " seconds"
<< std::endl;
}
{
Timer timer;
long long s = 0;
for (long long i : a) {
s += modinv_3(i, MOD);
}
double t = timer.seconds();
std::cout << "Simple Euclid\t: result=" << s << " in " << t
<< " seconds" << std::endl;
}
{
Timer timer;
long long s = 0;
for (long long i : a) {
s += modinv_4(i, MOD);
}
double t = timer.seconds();
std::cout << "Simple Euclid 2\t: result=" << s << " in " << t
<< " seconds" << std::endl;
}
{
Timer timer;
std::vector<long long> table = gen_table(1000000, MOD);
long long s = 0;
for (long long i : a) {
s += modinv_5(i, MOD, table);
}
double t = timer.seconds();
std::cout << "Table Euclid\t: result=" << s << " in " << t << " seconds"
<< std::endl;
}
}
参考文献
R. クランドール・C. ポメランス (2010) 『素数全書』(和田秀男 監訳) p479 朝倉書店