# ウィナーフィルタを再帰最小二乗法(RLS)[^1]で解く
数式[^2]は,逆行列の補助定理(Sherman–Morrison の公式)
$$
(\boldsymbol A + \boldsymbol x \boldsymbol x^H )^{-1} = \boldsymbol A^{-1} - \frac{\boldsymbol A^{-1} \boldsymbol x \boldsymbol x^H \boldsymbol A^{-1} }{1 + \boldsymbol x^H \boldsymbol A^{-1} \boldsymbol x}
$$
を使って
\begin{align}
\hat{\boldsymbol R}_k &= \hat{\boldsymbol R}_{k-1} + \boldsymbol u_k \boldsymbol u_k^H \\
\hat{\boldsymbol r}_k &= \hat{\boldsymbol r}_{k-1} + \boldsymbol u_k d_k^* \\
\\
\boldsymbol P_k &= \hat{\boldsymbol R}_k^{-1} \\
\boldsymbol P_k &= (\hat{\boldsymbol R}_{k-1} + \boldsymbol u_k \boldsymbol u_k^H )^{-1} \\
&= \boldsymbol P_{k-1} - \frac{\boldsymbol P_{k-1} \boldsymbol u_k \boldsymbol u_k^H \boldsymbol P_{k-1} }{1 + \boldsymbol u_k^H \boldsymbol P_{k-1} \boldsymbol u_k} \\
\\
\boldsymbol g_k &= \frac{\boldsymbol P_{k-1} \boldsymbol u_k }{1 + \boldsymbol u_k^H \boldsymbol P_{k-1} \boldsymbol u_k} \\
\boldsymbol P_k &= \boldsymbol P_{k-1} - \boldsymbol g_k \boldsymbol u_k^H \boldsymbol P_{k-1} \\
\\
\boldsymbol g_k + \boldsymbol g_k \boldsymbol u_k^H \boldsymbol P_{k-1} \boldsymbol u_k &= \boldsymbol P_{k-1} \boldsymbol u_k \\
\boldsymbol g_k &= \boldsymbol P_{k-1} \boldsymbol u_k - \boldsymbol g_k \boldsymbol u_k^H \boldsymbol P_{k-1} \boldsymbol u_k \\
&= \boldsymbol P_k \boldsymbol u_k \\
\\
\boldsymbol w_k &= \boldsymbol P_k \hat{\boldsymbol r}_k \\
&= \boldsymbol P_k (\hat{\boldsymbol r}_{k-1} + \boldsymbol u_k d_k^*) \\
&= (\boldsymbol P_{k-1} - \boldsymbol g_k \boldsymbol u_k^H \boldsymbol P_{k-1}) \hat{\boldsymbol r}_{k-1} + \boldsymbol P_k \boldsymbol u_k d_k^* \\
&= \boldsymbol w_{k-1} - \boldsymbol g_k \boldsymbol u_k^H \boldsymbol w_{k-1} + \boldsymbol g_k d_k^* \\
&= \boldsymbol w_{k-1} + \boldsymbol g_k ( d_k^* - \boldsymbol u_k^H \boldsymbol w_{k-1}) \\
&= \boldsymbol w_{k-1} + \boldsymbol g_k \xi_k^* \\
\\
\xi_k &= d_k - \boldsymbol w_{k-1}^H \boldsymbol u_k \\
\end{align}
となる。整理すると,初期値を
$$
\boldsymbol P_0 = \alpha^{-1} \boldsymbol I, \quad \boldsymbol w_0 = \boldsymbol 0
$$
として,
\begin{align}
\boldsymbol g_k &= \frac{\boldsymbol P_{k-1} \boldsymbol u_k }{1 + \boldsymbol u_k^H \boldsymbol P_{k-1} \boldsymbol u_k} \\
\xi_k &= d_k - \boldsymbol w_{k-1}^H \boldsymbol u_k \\
\boldsymbol w_k &= \boldsymbol w_{k-1} + \boldsymbol g_k \xi_k^* \\
\boldsymbol P_k &= (\boldsymbol I - \boldsymbol g_k \boldsymbol u_k^H ) \boldsymbol P_{k-1} \\
y_k &= \boldsymbol w_k^H \boldsymbol u_k
\end{align}
となる。Pythonのコードを以下に示す。ただし,1サンプルごとに $K \times K$ 行列 $\boldsymbol P_k$ を更新するため計算量が大きく,かなり遅い。$K$ を小さくするなどしないと実用は微妙かもしれない。なお,忘却係数は未適用。
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
import scipy.signal as sg
from mpl_toolkits.mplot3d import Axes3D
# Reference (true) impulse response w_ref: a 512-tap bell-shaped curve
# peaking at tap 128, normalized so that its taps sum to 2.
z = np.arange(512)
h0 = 1.0 / (1.0 + np.exp(1e-4 * np.square(z - 128)))
Wref = 2 * h0 / h0.sum()

# Show the reference response that the adaptive filter should converge to.
plt.plot(Wref)
plt.title("$w_{ref}$")
plt.show()
# Input signal y1: one cycle per unit of x plus a weaker component at ten
# cycles per unit, sampled at 8192 points over x in [0, 2].
x = np.linspace(0, 2, 8192)
y1 = np.sin(2 * np.pi * x) + 0.2 * np.sin(np.pi * 20 * x + 0.1 * np.pi)
# Desired signal y2: the input passed through the reference FIR filter Wref.
y2 = sg.lfilter(Wref, [1], y1)

# RLS settings: alpha fixes the initial inverse correlation P_0 = I / alpha,
# K is the number of adaptive filter taps.
α = 1e-5
K = 512

# State for one RLS update per input sample.
N = np.arange(x.size)
yk = np.zeros(x.size)           # filter output y_k = w_k^H u_k per sample
last_Pk = np.identity(K) / α    # P_0
last_wk = np.zeros(K)           # w_0
wk = np.zeros((N.size, K))      # history of tap-weight vectors w_k
ξk = np.zeros(N.size)           # history of a-priori errors xi_k
# RLS adaptation loop (no forgetting factor): one update per sample n.
# P_k is updated as the rank-1 correction P_{k-1} - g_k (u_k^H P_{k-1}),
# which is algebraically equal to (I - g_k u_k^H) P_{k-1} but costs O(K^2)
# per sample instead of an O(K^3) matrix-matrix product, and avoids
# allocating a fresh K x K identity matrix every iteration.
for n in N:
    dk = y2[n]
    # Regressor u_k = [y1[n], y1[n-1], ..., y1[n-K+1]], zero-filled before
    # the start of the signal.
    if n < K:
        uk = np.zeros(K)
        uk[:n + 1] = y1[n::-1]
    else:
        uk = y1[n:n - K:-1]
    uk_H = np.conj(uk)                      # u_k^H (no-op for real data)
    Pu = last_Pk @ uk                       # P_{k-1} u_k, reused below
    gk = Pu / (1 + uk_H @ Pu)               # gain vector g_k
    ξk[n] = dk - np.conj(last_wk) @ uk      # xi_k = d_k - w_{k-1}^H u_k
    wk[n] = last_wk + gk * np.conj(ξk[n])   # w_k = w_{k-1} + g_k xi_k^*
    Pk = last_Pk - np.outer(gk, uk_H @ last_Pk)
    yk[n] = np.conj(wk[n]) @ uk             # y_k = w_k^H u_k
    last_wk = wk[n]
    last_Pk = Pk
# Magnitude of the a-priori error |xi_k| over time (matches the legend).
plt.plot(N, np.abs(ξk))
plt.xlim(0)
plt.legend([r"$\|ξ_k\|$"])  # raw string: "\|" is not a valid Python escape
plt.show()

# Desired signal y2 versus the adaptive filter output yk.
plt.plot(y2)
plt.plot(yk)
plt.xlim(0)
plt.legend(["$y2$", "$yk$"])
plt.show()

# Output error yk - y2.
plt.plot(yk - y2)
plt.xlim(0)
plt.legend(["$yk-y2$"])
plt.show()

# Relative L1 tap-weight error sum|w_k - w_ref| / sum|w_ref| per sample.
# Both sides are zero-padded to a common length so this keeps working when
# K is changed away from len(Wref); for K == len(Wref) it is unchanged.
taps = max(wk.shape[1], Wref.size)
wk_pad = np.pad(wk, ((0, 0), (0, taps - wk.shape[1])))
Wref_pad = np.pad(Wref, (0, taps - Wref.size))
err = np.sum(np.abs(wk_pad - Wref_pad), axis=1) / np.sum(np.abs(Wref))
plt.plot(N, err)
plt.xlim(0)
plt.legend([r'$\frac{ \| w_k-w_{ref} \| }{ \| w_{ref} \| }$'])
plt.show()

# Evolution of every tap weight: tap index vs. sample index.
XX, YY = np.meshgrid(np.arange(wk.shape[1]), N)
plt.figure().add_subplot(111, projection='3d').plot_wireframe(XX, YY, wk)
plt.title("$wk$")
plt.show()