ウィナーフィルタをアフィン射影法(APA1)で解く
数式では,
\begin{align}
& \min_{\boldsymbol w_k} \|\boldsymbol w_k - \boldsymbol w_{k-1}\|^2 \\
& \text{subject to } \boldsymbol w_k^H \boldsymbol U_k = \boldsymbol d_k \\
\\
{\boldsymbol d}_k &= [d_{k-L_s+1}, \cdots, d_{k}] \\
{\boldsymbol U}_k &= [ {\boldsymbol u}_{k-L_s+1}, \cdots, {\boldsymbol u} _ {k} ] \\
{\boldsymbol u}_k &= [ u_k, \cdots, u_{k-K+1} ]^T \\
\end{align}
を解くと,こうなる。
\begin{align}
\boldsymbol \xi_k &= \boldsymbol d_k - \boldsymbol w^H_{k-1} \boldsymbol U_k \\
\Delta \boldsymbol w &= \boldsymbol U_k ( \boldsymbol U_k^H \boldsymbol U_k )^{-1} \boldsymbol \xi^H_k\\
\boldsymbol w_k &= \boldsymbol w_{k-1} + \mu \boldsymbol U_k ( \boldsymbol U_k^H \boldsymbol U_k )^{-1} \boldsymbol \xi^H_k\\
\end{align}
ただし,逆行列計算の数値的安定性を確保するため,小さな正数 $\alpha$ による正則化項 $\alpha \boldsymbol I$ を加えて下記のようにする。
\begin{align}
\boldsymbol w_k &= \boldsymbol w_{k-1} + \mu \boldsymbol U_k ( \alpha \boldsymbol I + \boldsymbol U_k^H \boldsymbol U_k )^{-1} \boldsymbol \xi^H_k\\
\end{align}
Python で書くとこう。なお,計算が重いので$n = L_s + shift \times i$として,下記のように$shift$サンプルずつずらしながら計算するようにしている(添字が負になる $u$ は 0 とみなす)。
\begin{align}
\boldsymbol w_{last} = \boldsymbol 0 \\
\text{For } i = 0, 1, \cdots & \\
n &= L_s + shift \times i \\
\boldsymbol dk &= [d_{n-L_s}, \cdots, d_{n-1}] \\
\boldsymbol Uk &=
\begin{bmatrix}
u_{n-L_s} & u_{n-L_s+1} & \cdots & u_{n-L_s+K-1} & \cdots & u_{n-1} \\
u_{n-L_s-1} & u_{n-L_s} & \cdots & u_{n-L_s+K-2} & \cdots & u_{n-2} \\
\vdots & \vdots & \ddots & \vdots & & \vdots \\
u_{n-L_s-K+1} & u_{n-L_s-K+2} & \cdots & u_{n-L_s} & \cdots & u_{n-K} \\
\end{bmatrix} \\
\\
\boldsymbol\xi k[i] &= \boldsymbol dk - {\boldsymbol w_{last}}^H \, \boldsymbol Uk \\
\boldsymbol w k[i] &= \boldsymbol w_{last} + \mu \, \boldsymbol Uk \, ( \alpha \boldsymbol I + {\boldsymbol Uk}^H \boldsymbol Uk )^{-1} \, {\boldsymbol \xi k[i]}^H\\
{\boldsymbol w_{last}} &= \boldsymbol w k[i] \\
\end{align}
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
import scipy.signal as sg
from mpl_toolkits.mplot3d import Axes3D
# Reference ("unknown") FIR system to be identified: 512 taps whose envelope
# is a bell shape centred on tap 128, rescaled so the taps sum to 2.
z = np.arange(512)
h0 = 1 / (1 + np.exp(1e-4 * (z - 128) ** 2))
Wref = 2 * h0 / h0.sum()
# Visual check of the reference filter taps.
_ax_ref = plt.gca()
_ax_ref.plot(Wref)
_ax_ref.set_title("$w_{ref}$")
plt.show()
# Excitation signal (u in the derivation above): 8192 samples on x in [0, 2],
# a unit-amplitude sine with period 1 plus a 0.2-amplitude sine with period
# 0.1 and a phase offset of 0.1*pi.
x = np.linspace(0, 2, 8192)
_carrier = np.sin(2 * np.pi * x)
_ripple = 0.2 * np.sin(np.pi * 20 * x + 0.1 * np.pi)
y1 = _carrier + _ripple
# Desired signal d: the excitation passed through the reference FIR filter
# (numerator Wref, denominator 1, i.e. no feedback).
y2 = sg.lfilter(b=Wref, a=[1], x=y1)
# --- Regularised affine projection algorithm (APA) identifying Wref ---------
μ = 0.5         # step size of the APA update
α = 0.00001     # regularisation added to U^H U before solving
Ls = 512        # projection order: number of input vectors (columns of U) per update
K = 512         # number of adaptive filter taps (rows of U)
shift = 128     # hop, in samples, between successive updates

if len(x) < Ls:
    raise ValueError(f"signal length {len(x)} is shorter than Ls={Ls}")

# Update instants n = Ls, Ls + shift, Ls + 2*shift, ..., always ending at len(x).
N = np.append(np.arange(Ls, len(x), shift), len(x))


def _input_matrix(u, n, n_taps, n_cols):
    """Return the (n_taps, n_cols) data matrix with U[k, l] = u[n - n_cols - k + l].

    Column l is the tap-input vector at time n - n_cols + l; samples with a
    negative time index are taken as zero (filter starts from rest).
    Requires n <= len(u).
    """
    idx = n - n_cols - np.arange(n_taps)[:, None] + np.arange(n_cols)[None, :]
    # clamp negative indices for the gather, then zero those entries out
    return np.where(idx >= 0, u[np.maximum(idx, 0)], 0)


yk = np.zeros(len(x))            # filter output, written window by window after each update
last_wk = np.zeros(K)            # current tap estimate w_{k-1}
wk = np.zeros((len(N), K))       # tap estimate after each update
ξk = np.zeros((len(N), Ls))      # a-priori error vector of each update (real-valued storage)
_eye_Ls = np.eye(Ls)             # loop-invariant identity for the regulariser
for i, n in enumerate(N):
    dk = y2[n-Ls:n]
    Uk = _input_matrix(y1, n, K, Ls)
    # a-priori error: ξ = d - w^H U
    ξk[i] = dk - np.conj(last_wk) @ Uk
    # w_k = w_{k-1} + μ U (αI + U^H U)^{-1} ξ^H ; solve() avoids forming the inverse
    gram = α * _eye_Ls + np.conj(Uk.T) @ Uk
    wk[i] = last_wk + μ * Uk @ np.linalg.solve(gram, np.conj(ξk[i]))
    # output over this window using the freshly updated taps: y = w^H U
    yk[n-Ls:n] = np.conj(wk[i]) @ Uk
    last_wk = wk[i]
# Euclidean norm of the a-priori error vector ξ_k at every update instant n.
# np.linalg.norm uses |.|^2, so this also stays correct for complex ξ.
plt.plot(N, np.linalg.norm(ξk, axis=1))
plt.xlim(0)
# raw string: "\|" is an invalid escape in a normal literal (SyntaxWarning on
# Python >= 3.12); the raw form yields exactly the same text for mathtext
plt.legend([r"$\|ξ_k\|$"])
plt.show()
# Desired signal y2 overlaid with the adaptive-filter output yk.
for _trace in (y2, yk):
    plt.plot(_trace)
plt.xlim(0)
plt.legend(["$y2$", "$yk$"])
plt.show()

# Sample-wise difference between the adaptive output and the desired signal.
_residual = yk - y2
plt.plot(_residual)
plt.xlim(0)
plt.legend(["$yk-y2$"])
plt.show()
# Relative L1 coefficient error  sum|w_k - w_ref| / sum|w_ref|  after each update.
# wk is (len(N), K) and Wref is (K,), so the subtraction broadcasts row-wise.
err = np.sum(np.abs(wk - Wref), axis=1) / np.sum(np.abs(Wref))
plt.plot(N, err)
plt.xlim(0)
# raw string avoids the invalid "\|" escape; the resulting text is unchanged
plt.legend([r'$\frac{ \| w_k-w_{ref} \| }{ \| w_{ref} \| }$'])
plt.show()
# Evolution of every tap (x axis) across the update instants N (y axis).
_tap_index = np.arange(wk.shape[1])
XX, YY = np.meshgrid(_tap_index, N)
_ax3d = plt.figure().add_subplot(111, projection='3d')
_ax3d.plot_wireframe(XX, YY, wk)
_ax3d.set_title("$wk$")
plt.show()