はじめに
数値計算において、非常に大きな$n!$や$\log n!$などの階乗の計算は時間がかかる場合があります。そのため、階乗の計算にはスターリングの級数などを利用して近似的に求めることがあります。その方法については以下で紹介しました。
しかし、これらの近似式は処理速度は向上しますが、$n$が20程度の小さい場合において、精度が10桁程度であることがわかり、どんな$n$についても高精度かつ高速な計算が必要な場合には、さらなる検討が必要です。
本記事では、ガンマ関数を用いて階乗の計算を高速化する方法について紹介します。ガンマ関数を利用することで、高精度な計算が可能となります。また、math.factorial()
とmath.gamma()
の比較や多次元配列での階乗の計算方法などについても解説します。
ガンマ関数
ガンマ関数は、階乗を一般化したもので、階乗は自然数に対して定義されますが、ガンマ関数は自然数以外の実数や複素数にも拡張されます。ガンマ関数は以下で表されます。
\Gamma (z)=\int _{0}^{\infty }t^{z-1}e^{-t}\,dt,\ \qquad \Re (z)>0
特に自然数の場合、
n!=\Gamma (n+1)
と書けます。ガンマ関数を使うことで、$n!$の計算が、範囲は0から正の無限大の積分で表現することができます。数値計算において、大きな$n$に対して、積分で表現した方が早い場合があり、その事について実験してみます。
数値実験
Google Colabで作成した本記事のコードは、こちらにあります。
数値実験では、math.facatorial(n)
とmath.gamma(n+1)
を$n=2, 3, 4, 99$について計算し、精度と処理速度について確認します。
各種インポート
import math
import time
import numpy as np
import scipy.special
import matplotlib.pyplot as plt
精度比較
n_range = list(range(2, 100))
factorials = np.zeros((len(n_range), 1))
gammas = np.zeros((len(n_range), 1))
for i, n in enumerate(n_range):
factorials[i] = math.factorial(n)
gammas[i] = math.gamma(n+1)
fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(16, 6))
ax0.set_title('factorial vs. gamma')
ax0.plot(n_range, factorials, label='factorial')
ax0.plot(n_range, gammas, label='gamma')
ax0.set_yscale('log')
ax0.set_xlabel('factorial: n, gamma: n+1')
ax0.set_ylabel('value')
ax0.legend()
ax1.set_title('abs(gamma - factorial)')
ax1.plot(n_range, abs(gammas - factorials))
ax1.set_yscale('log')
ax1.set_xlabel('factorial: n, gamma: n+1')
ax1.set_ylabel('value')
ax2.set_title('abs(gamma - factorial) / factorial')
ax2.plot(n_range, abs(gammas - factorials) / factorials)
ax2.set_xlabel('factorial: n, gamma: n+1')
ax2.set_ylabel('value')
plt.show()
上図は、左から、factorial()
とgamma()
の比較、差の絶対値、差の絶対値の割合を示しています。一番左のグラフから、割合で言えば約16桁の精度を持っていることがわかります。
math.factorial()
関数は整数型の戻り値を返しますが、math.gamma()
関数は浮動小数点数型(float 型)を返すため、15桁の精度が限界となります。つまり、math.gamma()
の結果はfloat型の制約による精度の限界と言えます。
処理速度比較
def processing_func_speed(func, n, iteration):
start_time = time.time()
for _ in range(iteration):
func(n)
end_time = time.time()
elapsed_avg_time = (end_time - start_time) / iteration
return elapsed_avg_time
n_range = list(range(2, 100, 10))
iteration = 500000
factorial_times = np.zeros((len(n_range), 1))
gamma_times = np.zeros((len(n_range), 1))
for i, n in enumerate(n_range):
factorial_times[i] = processing_func_speed(math.factorial, n, iteration)
gamma_times[i] = processing_func_speed(math.gamma, n+1, iteration)
fig, ax = plt.subplots()
ax.plot(n_range, factorial_times, '.-', label='factorial')
ax.plot(n_range, gamma_times, '.-', label='gamma')
ax.legend()
ax.set_ylim(0,)
ax.set_xlabel('factorial: n, gamma: n+1')
ax.set_ylabel('time')
plt.show()
gamma()は、0から無限までの積分を計算するため、n
によらずほとんど一定の速度となります。一方、factoraial()は、厳密に階乗を計算しているためn
に比例して計算量が増加していることがグラフからわかります。
n
が約100程度では、処理速度において10倍程度処理速度の差が生じており、また精度も約16桁あるため、ガンマ関数を使った計算の方が良いと思います。
実際に、gamma()
関数の計算のソースコードの以下を参照すると、
if num <= 0:
raise ValueError("math domain error")
return quad(integrand, 0, inf, args=(num))[0]
def integrand(x: float, z: float) -> float:
return math.pow(x, z - 1) * math.exp(-x)
となり、scipy.integrate.guad()関数を使って、0
からinf
までの積分を行なっていることがわかります。
おまけ
対数の階乗
ここでは、対数の階乗を計算する方法について紹介します。対数の階乗は、階乗を計算した後に対数を取る方法と、対数の階乗が提供されている scipy.special.gammaln()
関数を使って計算する主に2つの方法があります。
精度と処理速度について、math.log(math.factorial(n))
を正しい計算結果として、math.log(math.gamma(n+1))
とscipy.special.gammaln(n+1)
計算します。
def factorial_log(n):
return math.log(math.factorial(n))
def gamma_log(n):
return math.log(math.gamma(n))
精度比較
n_range = list(range(2, 100))
factorial_logs = np.zeros((len(n_range), 1))
gamma_logs = np.zeros((len(n_range), 1))
gammalns = np.zeros((len(n_range), 1))
for i, n in enumerate(n_range):
factorial_logs[i] = math.log(math.factorial(n))
gamma_logs[i] = math.log(math.gamma(n+1))
gammalns[i] = scipy.special.gammaln(n+1)
fig, (ax0, ax1, ax2) = plt.subplots(1, 3, figsize=(16, 6))
ax0.set_title('factorial vs. gamma')
ax0.plot(n_range, factorial_logs, label='factorial')
ax0.plot(n_range, gamma_logs, label='gamma_log')
ax0.plot(n_range, gammalns, label='gammaln')
ax0.set_yscale('log')
ax0.set_xlabel('factorial: n, gamma: n+1')
ax0.set_ylabel('value')
ax0.legend()
ax1.set_title('abs(gamma - factorial)')
ax1.plot(n_range, abs(gamma_logs - factorial_logs), label='gamma_log')
ax1.plot(n_range, abs(gammalns - factorial_logs), label='gammaln')
ax1.set_yscale('log')
ax1.set_xlabel('factorial: n, gamma: n+1')
ax1.set_ylabel('value')
ax1.legend()
ax2.set_title('abs(gamma - factorial) / factorial')
ax2.plot(n_range, abs(gamma_logs - factorial_logs) / factorial_logs, label='gamma_log')
ax2.plot(n_range, abs(gammalns - factorial_logs) / factorial_logs, label='gammaln')
ax2.set_xlabel('factorial: n, gamma: n+1')
ax2.set_ylabel('value')
ax2.legend()
plt.show()
一番左の図から、割合として16桁の精度でfloat型の限界の精度を持っています。
処理速度比較
n_range = list(range(2, 100, 10))
iteration = 500000
factorial_log_times = np.zeros((len(n_range), 1))
gamma_log_times = np.zeros((len(n_range), 1))
gammaln_times = np.zeros((len(n_range), 1))
for i, n in enumerate(n_range):
factorial_log_times[i] = processing_func_speed(factorial_log, n, iteration)
gamma_log_times[i] = processing_func_speed(gamma_log, n+1, iteration)
gammaln_times[i] = processing_func_speed(scipy.special.gammaln, n+1, iteration)
fig, ax = plt.subplots()
ax.plot(n_range, factorial_log_times, '.-', label='factorial_log')
ax.plot(n_range, gamma_log_times, '.-', label='gamma_log')
ax.plot(n_range, gammaln_times, '.-', label='gammaln')
ax.legend()
ax.set_ylim(0,)
ax.set_xlabel('factorial: n, gamma: n+1')
ax.set_ylabel('time')
plt.show()
青色のラインはmath.log(math.factorial(n))
、オレンジ色のラインはmath.log(math.gamma(n+1))
、そして緑色のラインはscipy.special.gammaln(n+1)
を表しています。結果を見ると、scipyのモジュールは高度な計算が可能なため、他の方法に比べて最も処理が遅いことがわかります。
実際に、scipyは数値計算や科学技術計算のための高度な機能を提供していますが、その反面、より複雑な計算や柔軟性を求める場合には計算時間が増える可能性があります。以下に、scipyのモジュールの利点について紹介します。
多次元配列の階乗の計算
多次元配列の階乗の計算はscipyモジュールが便利です。一方で、mathモジュールでは配列を直接受け取ることができないため、そこが一つの差別化ポイントです。
ただし、scipy.special.factorial(array)
は戻り値が、float型なので厳密な計算が必要なときは、math.facatorial()
を使いましょう。以下が実装例です。
array = np.arange(2*3).reshape(2, 3)
# factorial_array = math.factorial(array) # mathモジュールは多次元に対応していないためエラーになる
factorial_array = scipy.special.factorial(array)
gamma_array = scipy.special.gamma(array+1)
print(f'{array=}')
print(f'{factorial_array=}')
print(f'{gamma_array=}')
array=array([[0, 1, 2],
[3, 4, 5]])
factorial_array=array([[ 1., 1., 2.],
[ 6., 24., 120.]])
gamma_array=array([[ 1., 1., 2.],
[ 6., 24., 120.]])
大きな数の階乗の計算
float型で扱える最大値(1.7976931348623157e308
)を超えるような計算は、注意が必要です。整数型math.factorial()
では計算することはできますが、戻り値がfloat型となるmath.gamma()
などでは計算できないです。float型の最大値を超える場合、mathモジュールでは、OverflowError
のエラーを返し、scipyモジュールでは、inf
を返される仕様になっています。
そのため、このような計算を行う必要がある場合、桁数を増やすための特殊なモジュールやライブラリの利用もできますが、整数型が戻り値のmath.factorial()
を使うことが簡単な手段だと思います。
for n in range(165, 175):
# math.gammaは戻り値がfloat型なので171!を超えると計算不可 -> OverflowError: math range error
try:
math_gamma = math.gamma(n+1)
except OverflowError as e:
math_gamma = e
# scipyのgammaは戻り値がfloat型なので171!はfloat型の最大値infを返す
scipy_gamma = scipy.special.gamma(n+1)
# 戻り値がint型なので大きな数も正確に計算可能
math_factorial = math.factorial(n)
print(f'{n=}, {math_gamma=}, {scipy_gamma=}, {math_factorial=}')
n=165, math_gamma=5.42391066613159e+295, scipy_gamma=5.42391066613159e+295, math_factorial=54239106661315887749844950142128418438215720379413698342567269131165730172604765680403290849565015553604650354393935095487058155225915554489271207416431263907819225703574088519286376487014046409943976621391856730220500349076942933314077895545443758280540160000000000000000000000000000000000000000
n=166, math_gamma=9.003691705778436e+297, scipy_gamma=9.00369170577844e+297, math_factorial=9003691705778437366474261723593317460743809582982673924866166675773511208652391102946946281027792581898371958829393225850851653767501982045219020431127589808697991466793298694201538496844331704050700119151048217216603057946772526930136930660543663874569666560000000000000000000000000000000000000000
n=167, math_gamma=1.5036165148649994e+300, scipy_gamma=1.503616514864999e+300, math_factorial=1503616514864999040201201707840084015944216200358106545452649834854176371844949314192140028931641361177028117124508668717092226179172831001551576411998307498052564574954480881931656928973003394576466919898225052275172710677111011997332867420310791867053134315520000000000000000000000000000000000000000
n=168, math_gamma=2.526075744973198e+302, scipy_gamma=2.5260757449731984e+302, math_factorial=252607574497319838753801886917134114678628321660161899636045172255501630469951484784279524860515748677740723676917456344471493998101035608260664837215715659672830848592352788164518364067464570288846442542901808782229015393754650015551921726612213033664926565007360000000000000000000000000000000000000000
n=169, math_gamma=4.269068009004705e+304, scipy_gamma=4.269068009004706e+304, math_factorial=42690680090047052749392518888995665380688186360567361038491634111179775549421800928543239701427161526538182301399050122215682485679075017796052357489455946484708413412107621199803603527401512378815048789750405684196703601544535852628274771797464002689372589486243840000000000000000000000000000000000000000
n=170, math_gamma=7.257415615307998e+306, scipy_gamma=7.257415615308e+306, math_factorial=7257415615307998967396728211129263114716991681296451376543577798900561843401706157852350749242617459511490991237838520776666022565442753025328900773207510902400430280058295603966612599658257104398558294257568966313439612262571094946806711205568880457193340212661452800000000000000000000000000000000000000000
n=171, math_gamma=OverflowError('math range error'), scipy_gamma=inf, math_factorial=1241018070217667823424840524103103992616605577501693185388951803611996075221691752992751978120487585576464959501670387052809889858690710767331242032218484364310473577889968548278290754541561964852153468318044293239598173696899657235903947616152278558180061176365108428800000000000000000000000000000000000000000
n=172, math_gamma=OverflowError('math range error'), scipy_gamma=inf, math_factorial=213455108077438865629072570145733886730056159330291227886899710221263324938130981514753340236723864719151973034287306573083301055694802251980973629541579310661401455397074590303866009781148657954570396550703618437210885875866741044575478989978191912006970522334798649753600000000000000000000000000000000000000000
n=173, math_gamma=OverflowError('math range error'), scipy_gamma=inf, math_factorial=36927733697396923753829554635211962404299715564140382424433649868278555214296659802052327860953228596413291334931704037143411082635200789592708437910693220744422451783693904122568819692138717826140678603271725989637483256524946200711557865266227200777205900363920166407372800000000000000000000000000000000000000000
n=174, math_gamma=OverflowError('math range error'), scipy_gamma=inf, math_factorial=6425425663347064733166342506526881458348150508160426541851455077080468607287618805557105047805861775775912692278116502462953528378524937389131268196460620409529506610362739317326974626432136901748478076969280322196922086635340638923811068556323532935233826663322108954882867200000000000000000000000000000000000000000
大きな数の対数の階乗の計算
float型で扱える最大値(1.7976931348623157e308
)を超えるような対数の階乗の計算は、階乗の計算よりも手段が増え、簡単な方法だと2つの選択肢があり、一つは、整数型が戻り値のmath.factorial()
のlogを取る方法で、もう一つはscipyのscipy.special.gammaln()
を使う方法です。ここで、精度と処理速度を確認しておきます。
# 関数の比較
print('関数の比較')
for n in range(165, 175):
gammaln_val = scipy.special.gammaln(n+1) # math.log(scipy.special.gamma(n+1))はn=170以上では使えない
factorial_log_val = math.log(math.factorial(n))
print(f'{n=}, {gammaln_val=}, {factorial_log_val=}, {abs(gammaln_val-factorial_log_val)/factorial_log_val=}')
# 処理時間の表示
print('処理時間')
iteration = 100000
for n in range(165, 175):
gammaln_time = processing_func_speed(scipy.special.gammaln, n+1, iteration)
factorial_log_time = processing_func_speed(factorial_log, n, iteration)
print(f'{n=}, {gammaln_time=}, {factorial_log_time=}')
関数の比較
n=165, gammaln_val=680.9534195136375, factorial_log_val=680.9534195136374, abs(gammaln_val-factorial_log_val)/factorial_log_val=1.669524441228527e-16
n=166, gammaln_val=686.0654073019941, factorial_log_val=686.065407301994, abs(gammaln_val-factorial_log_val)/factorial_log_val=1.6570845361333467e-16
n=167, gammaln_val=691.1834011144108, factorial_log_val=691.1834011144108, abs(gammaln_val-factorial_log_val)/factorial_log_val=0.0
n=168, gammaln_val=696.307365093814, factorial_log_val=696.307365093814, abs(gammaln_val-factorial_log_val)/factorial_log_val=0.0
n=169, gammaln_val=701.4372638087372, factorial_log_val=701.437263808737, abs(gammaln_val-factorial_log_val)/factorial_log_val=1.6207698619304514e-16
n=170, gammaln_val=706.5730622457875, factorial_log_val=706.5730622457874, abs(gammaln_val-factorial_log_val)/factorial_log_val=1.6089891307244474e-16
n=171, gammaln_val=711.71472580229, factorial_log_val=711.7147258022899, abs(gammaln_val-factorial_log_val)/factorial_log_val=1.5973652588607188e-16
n=172, gammaln_val=716.8622202791036, factorial_log_val=716.8622202791034, abs(gammaln_val-factorial_log_val)/factorial_log_val=1.5858952321040597e-16
n=173, gammaln_val=722.0155118736013, factorial_log_val=722.0155118736012, abs(gammaln_val-factorial_log_val)/factorial_log_val=1.5745761116212485e-16
n=174, gammaln_val=727.1745671728158, factorial_log_val=727.1745671728158, abs(gammaln_val-factorial_log_val)/factorial_log_val=0.0
処理時間
n=165, gammaln_time=1.6943073272705078e-06, factorial_log_time=2.2293972969055177e-06
n=166, gammaln_time=1.2653350830078125e-06, factorial_log_time=2.3422527313232423e-06
n=167, gammaln_time=1.2639617919921876e-06, factorial_log_time=2.2745561599731444e-06
n=168, gammaln_time=1.19215726852417e-06, factorial_log_time=2.380228042602539e-06
n=169, gammaln_time=1.4608263969421387e-06, factorial_log_time=2.2837090492248533e-06
n=170, gammaln_time=1.2564277648925782e-06, factorial_log_time=2.2177910804748535e-06
n=171, gammaln_time=1.2816047668457032e-06, factorial_log_time=2.459757328033447e-06
n=172, gammaln_time=1.2149405479431153e-06, factorial_log_time=3.720910549163818e-06
n=173, gammaln_time=2.583160400390625e-06, factorial_log_time=5.329210758209228e-06
n=174, gammaln_time=2.5930190086364744e-06, factorial_log_time=4.998750686645508e-06
精度は16桁あるのでfloat型の限界の精度に到達していますが、処理速度については、n
が170あたりでは、scipyで計算した方が2倍程度高速であることがわかります。より大きなn
で計算する場合はその差が開くことが予想されます。
まとめ
記事をまとめると以下のようになります。
-
階乗の計算の場合
- 精度重視: math.factorial()
- 速度重視: math.gamma()
- 多次元配列: scipy.special.gamma(), scipy.special.factorial()
- 大きな数(n>170): math.factorial()
-
対数の階乗の計算の場合
- 精度重視: math.log(math.factorial())
- 速度重視: math.log(math.gamma())
- 多次元配列: scipy.special.gammaln()
- 大きな数(n>170): scipy.special.gammaln(), math.log(math.factorial())
参考文献