はじめに
Pythonで線形回帰 $y=ax+b$ の係数 $a,b$ を、オンライン推定します。
数式
最小二乗法
$$
a = \frac{\overline{xy} - \bar{x}\bar{y}}{\overline{x^2} - \bar{x}^2}, \hspace{2em} b = \bar{y} - a \bar{x}
$$
ここで上線は平均を表します（$\overline{xy}$ は $xy$ の平均、$\overline{x^2}$ は $x^2$ の平均）。
平均値のオンライン推定
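$$
\bar{x}_n = \alpha \bar{x}_{n-1} + (1-\alpha)x_n
$$
ここで $\alpha$（通常 $0<\alpha<1$）は過去の推定値に掛ける重みで、ソースコードでは `alpha = 0.9` としています。初期値は最初の観測値（$\bar{x}_1 = x_1$）とし、$\bar{y}$, $\overline{xy}$, $\overline{x^2}$ も同じ式で更新します。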
ソースコード
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
# Suppress NumPy warnings for division by zero and invalid operations
# (e.g. 0/0 while the estimated variance of x is still zero).
# NOTE: scipy.seterr was only an alias of numpy.seterr and no longer exists.
np.seterr(divide='ignore', invalid='ignore')

# Weight given to the previous estimate in the exponential moving average
# m_n = alpha * m_{n-1} + (1 - alpha) * x_n; larger values react more slowly
# to new samples.
alpha = 0.9
def mean(old, new, weight=None):
    """Return the exponential moving average after observing *new*.

    Implements ``m_n = w * m_{n-1} + (1 - w) * x_n``. A NaN *old* means
    "no estimate yet", in which case *new* itself becomes the estimate.

    Args:
        old: Previous estimate, or NaN before the first sample.
        new: Newly observed value.
        weight: Weight ``w`` of the previous estimate. Defaults to the
            module-level ``alpha``.

    Returns:
        The updated estimate.
    """
    # np.isnan replaces scipy.isnan, a NumPy alias removed from SciPy.
    if np.isnan(old):
        return new
    w = alpha if weight is None else weight
    return w * old + (1.0 - w) * new
def plot(fig, history=100, noise=0.05):
    """Set up *fig* and return a FuncAnimation callback for online regression.

    Each call of the returned callback works in four steps:

    1. Draw one sample ``x ~ U[0, 1)`` and set ``y = x + noise * N(0, 1)``.
    2. Update exponential moving averages of x, y, x*y and x**2 with
       :func:`mean`.
    3. Estimate the least-squares line ``y = a*x + b`` from those averages.
    4. Redraw the scatter of the last *history* samples and the fitted line.

    Args:
        fig: Matplotlib figure to draw into.
        history: Number of most recent samples kept in the scatter plot.
        noise: Standard deviation of the Gaussian noise added to y.

    Returns:
        A function ``inner(i)`` suitable as the ``func`` argument of
        ``matplotlib.animation.FuncAnimation``. The frame index is unused.
    """
    # Rolling window of the most recent (x, y) samples; NaN rows are not drawn.
    xyhist = np.ones((history, 2)) * np.nan
    # Running averages of x, y, x*y and x**2; NaN means "no sample yet".
    # They are kept as plain scalars in a dict so the closure can update them
    # and so %-formatting never has to convert a 1-element array to a float.
    stats = {'x': np.nan, 'y': np.nan, 'xy': np.nan, 'x2': np.nan}

    ax = fig.gca()
    # Axes.hold() was removed in Matplotlib 3.0; overplotting is the default.
    ax.grid(True)
    ax.set_xlim([0, 1.0])
    ax.set_ylim([0, 1.0])
    xyscat = ax.scatter([], [], c='black', s=10, alpha=0.4)
    approx = ax.add_line(plt.Line2D([], [], color='r'))

    def inner(i):
        x = np.random.rand()
        y = x + noise * np.random.normal()

        # Shift the window left by one and append the new sample at the end.
        xyhist[:-1, :] = xyhist[1:, :]
        xyhist[-1, :] = (x, y)

        stats['x'] = mean(stats['x'], x)
        stats['y'] = mean(stats['y'], y)
        stats['xy'] = mean(stats['xy'], x * y)
        stats['x2'] = mean(stats['x2'], x ** 2)

        # Least squares: a = Cov(x, y) / Var(x), b = mean(y) - a * mean(x).
        var_x = stats['x2'] - stats['x'] ** 2
        if var_x > 0:
            a = (stats['xy'] - stats['x'] * stats['y']) / var_x
            b = stats['y'] - a * stats['x']
        else:
            # No spread in x yet (e.g. right after the first sample): the
            # slope is undefined, so show NaN instead of dividing by zero.
            a = b = np.nan

        xyscat.set_offsets(xyhist)
        approx.set_data([0, 1], [b, a + b])
        ax.title.set_text('y = %.3fx %+.3f' % (a, b))
        # No explicit draw here: FuncAnimation redraws the figure per frame.

    return inner
def main():
    """Run the online-regression animation for 100 frames and save it as a GIF."""
    fig = plt.figure()
    # Keep a reference to the animation until saving finishes; Matplotlib
    # does not keep unreferenced Animation objects alive.
    ani = animation.FuncAnimation(fig, plot(fig), interval=300, frames=100)
    # The 'imagemagick' writer requires ImageMagick to be installed.
    ani.save('result.gif', writer='imagemagick')


if __name__ == '__main__':
    main()