3
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

平均値と分散の逐次計算(オンラインアルゴリズム)

Last updated at Posted at 2022-10-02

概要

平均値と分散の逐次計算を愚直にやると誤差が出るという話。

https://qiita.com/Ushio/items/f5630d87f55c7afa984e
上記のPython版(Welfordアルゴリズム + カハンの加算アルゴリズム)

結論

class Kahan:
    def __init__(self, s=0.0, c=0.0):
        self._sum = s
        self._c = c
    
    def add(self, x):
        y = x - self._c
        t = self._sum + y
        self._c = (t - self._sum) - y
        self._sum = t
        return self
    
    def __iadd__(self, other):
        return self.add(other)
    
    def __rsub__(self, other):
        return other - self._sum
    
    def __truediv__(self, other):
        return self._sum / other
    
    def __float__(self):
        return float(self._sum)
    
    def __str__(self):
        return str(self._sum)
    
    def __repr__(self):
        return f"{self.__class__.__name__}({self._sum}, {self._c})"
    
    def __format__(self, format_spec):
        return format(self._sum, format_spec)
    
    @property
    def dtype(self):
        return np.dtype(float)
    
    def __array__(self):
        return np.array(self._sum)

class OnlineVariance:
    def __init__(self):
        self._mean = Kahan()
        self._M2 = Kahan()
        self._n = 0
    
    def addSample(self, value):
        self._n += 1
        delta = value - self._mean
        self._mean += delta / self._n
        delta2 = value - self._mean
        self._M2 += delta * delta2
    
    @property
    def variance(self):
        return self._M2 / self._n
    
    @property
    def avarage(self):
        return self._mean

使い方

ov = OnlineVariance()
for x in ary:
    ov.addSample(x)

print(ov.avarage, ov.variance)

詳細

愚直な方法

import numpy as np

loc, scale = np.random.rand() * 10, np.random.rand() * 10
size = np.random.randint(5000, 100000)
ary = np.random.normal(loc, scale, size)

# オンラインで覚えておく値
n = 0 # 個数
m = 0 # 平均
v = 0 # 分散
for x in ary:
    # オンラインで計算
    old_m = m
    m = (n * m + x) / (n + 1)
    v = (n * (v + old_m ** 2) + x ** 2) / (n + 1) - m ** 2
    n += 1

print(ary.mean(), ary.var(), sep="\t")
print(m, v, sep="\t")
print(np.allclose(m, ary.mean()), np.allclose(v, ary.var()))

若干誤差が発生。numpy.allclose では誤差の範囲内。

9.620947870954014	57.53665170250013 # 一括
9.620947870953962	57.53665170249734 # 逐次(愚直)
True True 

Welfordアルゴリズム

Welfordアルゴリズムを適用。ただし、カハンの加算アルゴリズムなし

_n = 0
_M2 = 0
_mean = 0

for x in ary:
    _n += 1
    delta = x - _mean
    _mean += delta / _n
    delta2 = x - _mean
    _M2 += delta * delta2
    
print(ary.mean(), ary.var(), sep="\t")
print(_mean, _M2 / _n, sep="\t")
print(np.allclose(_mean, ary.mean()), np.allclose(_M2 / _n, ary.var()))

誤差減少。でも少し違う

9.620947870954014	57.53665170250013 # 一括
9.620947870953962	57.53665170249734 # 逐次(愚直)
9.62094787095407	57.5366517025     # Welfordアルゴリズム
True True

Welfordアルゴリズム + カハンの加算アルゴリズム

カハンの加算アルゴリズムはクラス化。dtype や __array__ 等なくてもいいのあるけど、
numpy.allclose を通すのに適当に追加。

_n = 0
_M2 = 0
_mean = 0

for x in ary:
    _n += 1
    delta = x - _mean
    _mean += delta / _n
    delta2 = x - _mean
    _M2 += delta * delta2
    
class Kahan:
    def __init__(self, s=0.0, c=0.0):
        self._sum = s
        self._c = c
    
    def add(self, x):
        y = x - self._c
        t = self._sum + y
        self._c = (t - self._sum) - y
        self._sum = t
        return self
    
    def __iadd__(self, other):
        return self.add(other)
    
    def __rsub__(self, other):
        return other - self._sum
    
    def __truediv__(self, other):
        return self._sum / other
    
    def __float__(self):
        return float(self._sum)
    
    def __str__(self):
        return str(self._sum)
    
    def __repr__(self):
        return f"{self.__class__.__name__}({self._sum}, {self._c})"
    
    @property
    def dtype(self):
        return np.dtype(float)
    
    def __array__(self):
        return np.array(self._sum)
    
_n = 0
_M2 = Kahan()
_mean = Kahan()

for x in ary:
    _n += 1
    delta = x - _mean
    _mean += delta / _n
    delta2 = x - _mean
    _M2 += delta * delta2  

print(ary.mean(), ary.var(), sep="\t")
print(_mean, _M2 / _n, sep="\t")
print(np.allclose(_mean, ary.mean()), np.allclose(_M2 / _n, ary.var()))

誤差はかなり減った。入力サンプルによっては最後の桁まで一致することもある。

9.620947870954014	57.53665170250013 # 一括
9.620947870954016	57.53665170250013 # Welfordアルゴリズム + カハンの加算アルゴリズム
9.62094787095407	57.5366517025     # Welfordアルゴリズム
9.620947870953962	57.53665170249734 # 逐次(愚直)

True True

そういえば、平均値と分散を逐次計算すると誤差大きかったよなで、
ちょっと調べてみたら、意外に根が深かった。

マージするの分散は、一括時と値が違ったので、そのまま使うのは危険かも。
平均と同じように何か重みづけとかいる気がするけど、並列で計算することが予定にないので、詳しくは追ってないです。

3
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?