LoginSignup
1
1

More than 3 years have passed since last update.

Donald Knuthの不偏分散逐次計算アルゴリズムをPythonで実装してみた

Last updated at Posted at 2020-08-25

 本稿では、優れた分散の計算アルゴリズムであるDonald Knuthのアルゴリズムを紹介します。分散の計算など普段ライブラリに頼り切りで、そのアルゴリズムなど特に考慮したことなどなかったので、これを機に知識を深められたらと思います。

不偏分散の公式による計算の問題点

 初めに、なぜこのようなアルゴリズムが必要なのかについて紹介します。不偏分散は、以下の数式で表されます。

\sigma^2 = \frac{1}{n(n-1)} \biggl(\sum_{i=1}^{n}x_i^2-\Bigl(\sum_{i=1}^{n}x_i\Bigr)^2\biggr) =\frac{1}{n-1}\Bigl(E[X^2]-E[X]^2\Bigr)

さて、ここでは2つの累積和が登場しています。1.$x$自身の累積和、および2.$x^2$の累積和です。このときの問題として、大きく分けて2つ考えられます。まず一つ目は、$x$が非常に大きな数のとき、数値計算の正確性が保証されないことです。2つ目は、サンプルが増えたときに再度大規模な計算を行わなければらならないことです。これらの課題を解決するため、次節でDonald Knuthのアルゴリズムを見ていきましょう。

Donald Knuthの逐次計算アルゴリズム

 前述した課題を解決するために、逐次的に分散を計算するように設計されたのがこのアルゴリズムです。以下にアルゴリズムを示します。$M_k$は平均を表し、$s_k$は分散の分子となります。

初期化:$M_1=x_1, S_1 = 0$
$k \geq 2$において以下の漸化式を計算

\begin{align}
M_k &= M_{k-1} + \frac{(x_k - M_{k-1})}{k}\\
S_k &= S_{k-1} + (x_k - M_{k-1})*(x_k - M_k)
\end{align}

求める不偏分散の推定量は$k \geq 2$において$s^2 = S_k/(k-1)$となります。

Pythonでの実装

先ほどのアルゴリズムが実際に動くことを確認するためにPythonで実装してみます。漸化式のため、再帰関数で実装することにします。逐次計算されている様子が見られるようにいちいちprint文を挟み確認できるようにしました。

def calc_var(x, k=0, M=0, s=0):
    print('k=', k)
    print('M=', M)
    print('s=', s)
    print('-----------------------')
    if k == 0:
        M = x[0]
        s = 0
    delta = x[k] - M
    M += delta/(k+1)
    s += delta*(x[k]-M)
    k += 1
    if k == len(x):
        return M, s/(k-1)
    else:
        return calc_var(x, k=k, M=M, s=s)

x = [3.3, 5, 7.2, 12, 4, 6, 10.3]
print(calc_var(x))

実行結果

k= 0
M= 0
s= 0
-----------------------
k= 1
M= 3.3
s= 0.0
-----------------------
k= 2
M= 4.15
s= 1.4449999999999996
-----------------------
k= 3
M= 5.166666666666667
s= 7.646666666666666
-----------------------
k= 4
M= 6.875
s= 42.6675
-----------------------
k= 5
M= 6.3
s= 49.279999999999994
-----------------------
k= 6
M= 6.25
s= 49.355
-----------------------
(6.828571428571428, 10.56904761904762)

実際にnumpyでも平均、分散を計算して比較してみます。キーワード引数ddof=1は不偏分散として計算するためのものです。

import numpy as np
x_arr = np.array(x)
print((x_arr.mean(), x_arr.var(ddof=1)))

出力結果

(6.828571428571428, 10.569047619047621)

このように逐次的に計算することができました。実際はnumpyよりもpythonの再帰のほうが早いことなどないと思うので自分で実装する機会はないと思いますが、分散の計算アルゴリズムとして優れているにも関わらずあまり知られていないため記事にさせていただきました。皆様の参考になれば幸いです。

参照リンク

Accurately computing running variance
https://www.johndcook.com/blog/standard_deviation/

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1