

Online Linear Regression in Python (Robust Estimation)


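Introduction

This is the robust-estimation version of the article below.

Formulas

We use a slightly adapted version of the biweight estimation method.
For the online estimation of the coefficients, see the previous article mentioned above.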
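
For reference, the script in the source code section computes the coefficients from four running means. Written out as a sketch (the overline notation for the running means of $x$, $y$, $xy$ and $x^2$ is an assumption of this note, not taken from the previous article):

```latex
a = \frac{\overline{xy} - \bar{x}\,\bar{y}}{\overline{x^2} - \bar{x}^2}, \qquad
b = \bar{y} - a\,\bar{x}
```

The numerator is the running covariance of $x$ and $y$, and the denominator is the running variance of $x$, so $a$ is undefined until the observed $x$ values have some spread.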

Online estimation of the mean

\bar{x}_n = (1-\alpha) \bar{x}_{n-1} + \alpha x_n
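
As a minimal sketch of the update above, the running mean can be maintained one sample at a time. The helper name `online_mean`, the use of `None` for "no estimate yet" (the script in this article uses NaN instead) and the sample values are illustrative assumptions.

```python
def online_mean(prev, x, alpha):
    """Exponentially weighted running mean: (1 - alpha) * prev + alpha * x."""
    if prev is None:  # no estimate yet: start from the first sample
        return x
    return (1.0 - alpha) * prev + alpha * x


samples = [1.0, 2.0, 3.0]
m = None
for x in samples:
    m = online_mean(m, x, alpha=0.5)
print(m)
```

A larger `alpha` makes the mean follow new samples faster; `alpha = 0` leaves the estimate unchanged, which is what lets a weight of zero discard an outlier entirely.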

Biweight estimation

  • $W$ is the tolerance for the error
  • Compute the error $d$ from the already estimated $a, b$, and make the weight smaller the larger the error is
d = y-(ax+b)\\
w(d) = \left\{\begin{array}{ll}
\left[ 1 - \left( \frac{d}{W} \right)^2 \right]^2 & \left(|d| \le W \right) \\
0 & \left(W < |d| \right)
\end{array}\right.
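
The weight function can be sketched directly from the definition above. Passing `W` as a keyword argument with a default of 0.1 is a choice made for this example; the script in this article uses a module-level `W` instead.

```python
def biweight(d, W=0.1):
    """Biweight: 1 at d = 0, falling smoothly to 0 at |d| = W, and 0 beyond."""
    u = d / W
    return (1.0 - u ** 2) ** 2 if abs(u) <= 1.0 else 0.0


for d in (0.0, 0.05, 0.1, 0.3):
    print(d, biweight(d))
```

With `W = 0.1`, a residual of 0.05 gets weight 0.5625, and any residual at or beyond 0.1 gets weight 0.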

Walkthrough

Roughly one sample in ten is an outlier whose noise has ten times the standard deviation (0.5 instead of 0.05).

outliers = [0,0,0,0,0,0,0,0,0,1]

...

y = x + 0.05 * sp.random.normal() + outliers[sp.random.randint(10)] * 0.5 * sp.random.normal()
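
To check what this sampling scheme produces, here is a small self-contained sketch. It uses NumPy's `default_rng` rather than the global random state used in the article; the helper name `sample`, the seed and the sample count are assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed, chosen only for reproducibility


def sample(rng):
    """Draw one (x, y) pair from y = x + noise, with occasional outliers."""
    x = rng.random()
    noise = 0.05 * rng.normal()
    # With probability 1/10, add extra noise with 10x the standard deviation.
    is_outlier = rng.integers(10) == 9
    if is_outlier:
        noise += 0.5 * rng.normal()
    return x, x + noise, is_outlier


draws = [sample(rng) for _ in range(10000)]
outlier_rate = np.mean([flag for _, _, flag in draws])
print(outlier_rate)
```

The outlier noise is added on top of the ordinary noise, so outlier residuals have a standard deviation slightly above 0.5.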

For roughly the first 30 samples the update runs as usual; after that, the step size is scaled by the biweight.

W = 0.1
def biweight(d):
    return ( 1 - (d/W) ** 2 ) ** 2 if abs(d/W) < 1 else 0

...

if count[0] < 30:
    alpha = 0.1
else:
    d = y - (a * x + b)
    alpha = biweight(d) * 0.1
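
Putting the pieces together without any plotting, the update could look like the sketch below. The class name `OnlineRobustLine`, its parameters, and the guard that skips the fit while the variance of `x` is zero are assumptions of this example, not part of the original script.

```python
import math


class OnlineRobustLine:
    """Online fit of y = a*x + b with biweight-scaled updates (illustrative)."""

    def __init__(self, W=0.1, base_alpha=0.1, warmup=30):
        self.W, self.base_alpha, self.warmup = W, base_alpha, warmup
        self.count = 0
        self.means = None          # running means of x, y, x*y, x**2
        self.a = self.b = math.nan

    def _alpha(self, x, y):
        # Plain updates during warm-up (or while no fit exists yet);
        # afterwards the step is scaled by the biweight of the residual.
        if self.count < self.warmup or math.isnan(self.a):
            return self.base_alpha
        u = (y - (self.a * x + self.b)) / self.W
        return self.base_alpha * ((1.0 - u * u) ** 2 if abs(u) < 1.0 else 0.0)

    def update(self, x, y):
        self.count += 1
        alpha = self._alpha(x, y)
        new = (x, y, x * y, x * x)
        if self.means is None:
            self.means = new
        else:
            self.means = tuple((1.0 - alpha) * m + alpha * v
                               for m, v in zip(self.means, new))
        mx, my, mxy, mx2 = self.means
        var = mx2 - mx * mx
        if var > 0.0:              # undefined until x has some spread
            self.a = (mxy - mx * my) / var
            self.b = my - self.a * mx
        return self.a, self.b


model = OnlineRobustLine()
xs = [(i * 0.37) % 1.0 for i in range(50)]   # deterministic x values in [0, 1)
for x in xs:
    model.update(x, 2.0 * x + 1.0)            # noise-free line y = 2x + 1
before = (model.a, model.b)
after = model.update(0.5, 100.0)              # gross outlier: weight is 0
print(before, after)
```

Because the outlier's residual is far larger than `W`, its step size is zero and the coefficients are left exactly as they were.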

Source code

(This was thrown together quickly, so it is messy.)

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt
import matplotlib.animation as animation
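# NumPy under the alias `sp`: SciPy has deprecated its top-level
# NumPy aliases (sp.array, sp.nan, sp.random, ...).
import numpy as sp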
sp.seterr(divide='ignore', invalid='ignore')

def mean(old, new, alpha):
    return new if sp.isnan(old) else ( 1.0 - alpha ) * old + alpha * new

W = 0.1
def biweight(d):
    return ( 1 - (d/W) ** 2 ) ** 2 if abs(d/W) < 1 else 0

def plot(fig):
    a = sp.array([sp.nan])
    b = sp.array([sp.nan])

    xyhist = sp.ones([100, 2]) * sp.nan
    mean_x0  = sp.array([sp.nan])
    mean_y0  = sp.array([sp.nan])
    mean_x20 = sp.array([sp.nan])
    mean_xy0 = sp.array([sp.nan])

    mean_x  = sp.array([sp.nan])
    mean_y  = sp.array([sp.nan])
    mean_x2 = sp.array([sp.nan])
    mean_xy = sp.array([sp.nan])

    ax = fig.gca()
    ax.grid(True)
    ax.set_xlim([0, 1.0])
    ax.set_ylim([0, 1.0])

    xyscat = ax.scatter([],[], c='black', s=10, alpha=0.4)
    approx0 = ax.add_line(plt.Line2D([], [], color='r'))
    approx = ax.add_line(plt.Line2D([], [], color='b'))

    outliers = [0,0,0,0,0,0,0,0,0,1]

    count = [0]

    def inner(i):
        x = sp.random.rand()
        y = x + 0.05 * sp.random.normal() + outliers[sp.random.randint(10)] * 0.5 * sp.random.normal()

        count[0] += 1
        if count[0] < 30:
            alpha = 0.1
        else:
            d = y - (a * x + b)
            alpha = biweight(d) * 0.1

        xyhist[:-1, :] = xyhist[1:, :]
        xyhist[-1, 0] = x
        xyhist[-1, 1] = y

        mean_x0[:]  = mean( mean_x0,  x,      0.1 )
        mean_y0[:]  = mean( mean_y0,  y,      0.1 )
        mean_xy0[:] = mean( mean_xy0, x *  y, 0.1 )
        mean_x20[:] = mean( mean_x20, x ** 2, 0.1 )

        mean_x[:]  = mean( mean_x,  x,      alpha )
        mean_y[:]  = mean( mean_y,  y,      alpha )
        mean_xy[:] = mean( mean_xy, x *  y, alpha )
        mean_x2[:] = mean( mean_x2, x ** 2, alpha )

        a0 = ( mean_xy0 - mean_x0 * mean_y0 ) / ( mean_x20 - mean_x0 ** 2 )
        b0 = mean_y0 - a0 * mean_x0

        a[:] = ( mean_xy - mean_x * mean_y ) / ( mean_x2 - mean_x ** 2 )
        b[:] = mean_y - a * mean_x

        ax.title.set_text('y = %.3fx %+.3f' % (a[0], b[0]))
        xyscat.set_offsets(xyhist)
        approx.set_data([0, 1], [b, a*1+b])
        approx0.set_data([0, 1], [b0, a0*1+b0])
        plt.draw()

    return inner


fig = plt.figure()
ani = animation.FuncAnimation(fig, plot(fig), interval=100, frames=300)
ani.save('result2.gif', writer='imagemagick')

Results

The biweight estimate (blue) appears to be somewhat more stable than the ordinary estimate (red).

result2.gif
