3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

ウィナー・ホッフ方程式をPythonで解いてみた,まとめ

Posted at

アルゴリズム比較

下記のアルゴリズムを実装

  • LMS: least-mean-square
  • NLMS: normalized-least-mean-square
  • APA: affine projection algorithm
  • RLS: recursive least squares

どのアルゴリズムを使うかは,入力データと必要な精度・処理時間に依存するのかと。
パラメータを変えると精度と時間がすごく変わる。入力データは変えていないけど,
入力データにも依存するんだろうな。。。

テスト環境

  • 入力データ長:4096サンプル
  • タップ数:512
  • x: 入力データ(ex. 再生する信号)
  • z: xに畳み込んだデータ(ex. エコー経路のインパルス応答を畳み込んだ信号)
  • d: フィルタで推定したい観測データ(ex. z にノイズを重畳した信号,マイク信号)

wiener_filter_inp.png

実行結果

処理時間

wiener_filter_time.png

入力は4096サンプルなので,16kHzサンプリングなら0.256s分の信号に相当する。つまり処理時間が0.256s以下でなければリアルタイム処理は不可能で,このままではLMS/NLMS以外のリアルタイム処理は難しそうだ。
ただし,もう少しサンプル数を増やさないと計測誤差も大きいだろう。(CPU: Core i7-6700 3.4GHz)

推定結果と誤差

wiener_filter_res.png

LMS/NLMSはずっとランダムノイズに振り回されているが,APA/RLSはすぐにノイズに振り回されなくなる。

wiener_filter_xi.png

推定後の信号から,ノイズ重畳前の信号を減算すると一目瞭然。

アルゴリズムの実装とテストコード

アルゴリズムの実装

  • 各アルゴリズムをインスタンス化する際にパラメータを指定
  • 処理はprocess/process_pair で行う。
    PyAudio等で収音した音声にも適用できるように,入力をイテレータにしてみた。
  • reporter は途中経過を観察するためのログ用。
wiener_filter.py
# -*- coding: utf-8 -*-

from abc import ABCMeta, abstractmethod
import numpy as np
from more_itertools import chunked
from scipy.linalg import toeplitz
from time import perf_counter

class BaseReporter(metaclass=ABCMeta):
    """Do-nothing progress reporter.

    Filters call :meth:`reset` when they are (re)initialised and
    :meth:`report` after every update; this base class ignores both.
    Subclasses override them to record intermediate filter state.
    """

    def reset(self):
        """Discard recorded state (nothing is recorded here)."""

    def report(self, n, y, xi, w):
        """Receive one progress update and ignore it.

        Args:
            n: number of samples covered by this update.
            y: filter output(s) of the update.
            xi: error signal of the update.
            w: tap weights after the update.
        """

class SamplingLogReporter(BaseReporter):
    """Reporter that keeps one out of every ``interval`` progress reports.

    Kept reports are appended to ``log_n`` (sample position), ``log_w``
    (tap weights), ``log_xi`` (error) and ``log_y`` (output).

    Args:
        interval: keep a report whenever the report counter is a multiple
            of this value.
        show_counter: when true, also print the sample position and the
            seconds elapsed since the last :meth:`reset` for each kept report.
    """

    def __init__(self, interval=100, show_counter=False):
        self.interval, self.show_counter = interval, show_counter
        self.reset()

    def reset(self):
        """Empty the logs and restart both the report counter and the timer."""
        self.log_n, self.log_w, self.log_xi, self.log_y = [], [], [], []
        self.count = 0
        self.start = perf_counter()

    def report(self, n, y, xi, w):
        """Count one report and log it if the counter hits a multiple of interval.

        The logged position ``n * (count - 1)`` assumes every report covers
        the same number ``n`` of samples.
        """
        self.count += 1
        if self.count % self.interval != 0:
            return
        position = n * (self.count - 1)
        if self.show_counter:
            print("report", position, perf_counter() - self.start)
        logs = (self.log_n, self.log_w, self.log_xi, self.log_y)
        for log, value in zip(logs, (position, w, xi, y)):
            log.append(value)


class AbstractWienerFilter(metaclass=ABCMeta):
    """Common driver for the adaptive (Wiener-type) filters.

    Subclasses implement :meth:`_reset`, which receives the keyword
    arguments given to the constructor or to :meth:`reset`, and
    :meth:`process_pair`, a generator over ``(inp, res)`` pairs.
    """

    def __init__(self, reporter=None, **args):
        """Create the filter.

        Args:
            reporter: object with ``reset()`` and ``report(n, y, xi, w)``
                methods. When omitted, a fresh ``BaseReporter`` (no-op) is
                created for this instance. The default is no longer a single
                instance shared by all filters and built at class-definition
                time.
            **args: filter parameters, forwarded to :meth:`_reset`.
        """
        self.reporter = BaseReporter() if reporter is None else reporter
        self.reset(**args)

    def report(self, n, y, xi, w):
        """Forward one progress update to the reporter."""
        self.reporter.report(n, y, xi, w)

    def reset(self, **args):
        """Reset the reporter, then (re)initialise the filter state.

        Note: ``reset()`` without arguments calls ``_reset()`` without
        arguments, so subclasses fall back to their own default parameters,
        not to the values given to the constructor.
        """
        self.reporter.reset()
        self._reset(**args)

    @abstractmethod
    def _reset(self, **args):
        """Set the filter parameters and clear its internal state."""
        pass

    def process(self, iter_inp, iter_res):
        """Pair input samples with desired samples and run :meth:`process_pair`.

        Pairing uses ``zip``, so processing stops at the shorter iterable.
        """
        return self.process_pair(zip(iter_inp, iter_res))

    @abstractmethod
    def process_pair(self, itr):
        """Filter an iterable of ``(inp, res)`` pairs, yielding outputs."""
        pass


class LMS(AbstractWienerFilter):
    """Least-mean-square (LMS) adaptive filter, updated once per sample.

    Keyword parameters (passed to the constructor or to ``reset``):
        MU: step size.
        K:  number of taps.

    Complex LMS convention: estimate ``w^H u``, error ``xi = d - w^H u``,
    update ``w <- w + MU * u * conj(xi)``. For real-valued signals this is
    identical to using ``w^T u``.
    """

    def _reset(self, MU=0.001, K=512):
        self.K      = K
        self.MU     = MU

        self.w = np.zeros(self.K)   # tap weights
        self.u = np.zeros(self.K)   # tap-input vector, newest sample first

    def process_pair(self, itr):
        """Filter ``(inp, res)`` pairs one sample at a time.

        Args:
            itr: iterable of ``(inp, res)`` pairs. ``inp`` is the reference
                input sample and ``res`` the desired (observed) sample.

        Yields:
            The output for each sample, computed with the weights *after*
            that sample's update (a posteriori estimate).
        """
        MU = self.MU
        for inp, res in itr:
            # Shift the new sample into the front of the tap-input vector.
            u = np.r_[inp, self.u[:-1]]
            d = res
            last_w = self.w

            # A priori error. Using conj(w) keeps the error consistent with
            # the u * conj(xi) update for complex data (no-op for real data).
            xi = d - np.conj(last_w) @ u
            w = last_w + MU * u * np.conj(xi)
            y = np.conj(w) @ u

            self.w = w
            self.u = u
            self.report(1, y, xi, w)
            yield y
        

class NLMS(AbstractWienerFilter):
    """Normalized LMS (NLMS) adaptive filter, updated once per sample.

    Keyword parameters (passed to the constructor or to ``reset``):
        ALPHA: small regulariser added to ``||u||^2``. This avoids division
               by zero while the tap-input vector is still all zeros.
        MU:    normalised step size.
        K:     number of taps.

    Complex LMS convention: estimate ``w^H u``, error ``xi = d - w^H u``,
    update ``w <- w + MU * u * conj(xi) / (ALPHA + ||u||^2)``. For
    real-valued signals this is identical to using ``w^T u``.
    """

    def _reset(self, ALPHA=0.00001, MU=0.5, K=512):
        self.K      = K
        self.MU     = MU
        self.ALPHA  = ALPHA

        self.w = np.zeros(self.K)   # tap weights
        self.u = np.zeros(self.K)   # tap-input vector, newest sample first

    def process_pair(self, itr):
        """Filter ``(inp, res)`` pairs one sample at a time.

        Args:
            itr: iterable of ``(inp, res)`` pairs. ``inp`` is the reference
                input sample and ``res`` the desired (observed) sample.

        Yields:
            The output for each sample, computed with the weights *after*
            that sample's update (a posteriori estimate).
        """
        MU    = self.MU
        ALPHA = self.ALPHA
        for inp, res in itr:
            # Shift the new sample into the front of the tap-input vector.
            u = np.r_[inp, self.u[:-1]]
            d = res
            last_w = self.w

            # A priori error. Using conj(w) keeps the error consistent with
            # the u * conj(xi) update for complex data (no-op for real data).
            xi = d - np.conj(last_w) @ u
            w = last_w + MU * u * np.conj(xi) / (ALPHA + (np.linalg.norm(u) ** 2))
            y = np.conj(w) @ u

            self.w = w
            self.u = u
            self.report(1, y, xi, w)
            yield y

class APA(AbstractWienerFilter):
    """Affine projection algorithm (APA) adaptive filter, updated block-wise.

    Keyword parameters (passed to the constructor or to ``reset``):
        ALPHA: regulariser added to the ``LS x LS`` Gram matrix.
        MU:    step size.
        LS:    projection order, i.e. the number of most recent tap-input
               vectors used in each update.
        K:     number of taps.
        SHIFT: number of new samples consumed (and outputs yielded) per update.

    Update, in the complex convention (identical to the transpose form for
    real signals)::

        xi = d - w^H U
        w <- w + MU * U (ALPHA I + U^H U)^{-1} conj(xi)

    The output lags the input. Nothing is yielded until ``LS`` samples have
    been received, and each update yields the outputs for the oldest
    ``SHIFT`` samples of its window. Samples still pending in the window
    when ``itr`` is exhausted are not yielded by that call. They stay in the
    filter state and are emitted by a later ``process_pair`` call that
    supplies more data.
    """

    def _reset(self, ALPHA=0.00001, MU=0.5, LS=512, K=512, SHIFT=128):
        assert LS >= SHIFT
        assert K  >= SHIFT

        self.K      = K
        self.MU     = MU
        self.ALPHA  = ALPHA
        self.LS     = LS
        self.SHIFT  = SHIFT

        self.w = np.zeros(self.K)           # tap weights
        self.U = np.zeros((self.K, 0))      # pending tap-input vectors (columns)
        self.c = np.zeros(self.K)           # tap-input vector of the last sample seen
        self.d = np.array([])               # desired samples aligned with U's columns

    def process_pair(self, itr):
        """Filter ``(inp, res)`` pairs in blocks of ``SHIFT`` samples.

        Args:
            itr: iterable of ``(inp, res)`` pairs. ``inp`` is the reference
                input sample and ``res`` the desired (observed) sample.

        Yields:
            Output samples (a posteriori estimates), lagging the input as
            described in the class docstring.
        """
        MU    = self.MU
        ALPHA = self.ALPHA
        LS    = self.LS
        SHIFT = self.SHIFT
        K     = self.K
        eye_ls = np.eye(LS)

        for chunk in chunked(itr, SHIFT):
            chunk = np.asarray(chunk)
            inp   = chunk[:, 0]
            res   = chunk[:, 1]

            # Tap-input vector of the chunk's first sample (newest first);
            # toeplitz() then gives one K-long input vector per sample as columns.
            self.c = np.r_[inp[0], self.c[:-1]]
            self.U = np.c_[self.U, toeplitz(self.c, inp)]
            self.d = np.r_[self.d, res]
            # Advance to the tap-input vector of the chunk's last sample.
            # Works for any chunk length, including SHIFT == 1 and a short final
            # chunk. Slicing with self.c[:-(SHIFT - 1)] instead left c empty
            # or too short in those cases.
            self.c = np.r_[inp[1:][::-1], self.c][:K]

            if len(self.d) < LS:
                continue

            U = self.U[:, :LS]
            d = self.d[:LS]
            last_w = self.w

            xi = d - np.conj(last_w) @ U
            # Solve the regularised LS x LS system instead of forming its inverse.
            gram = ALPHA * eye_ls + np.conj(U).T @ U
            w  = last_w + MU * U @ np.linalg.solve(gram, np.conj(xi))
            y  = np.conj(w) @ U

            self.w = w
            # Slide the window: drop the SHIFT oldest input vectors / samples.
            self.U = self.U[:, SHIFT:]
            self.d = self.d[SHIFT:]
            self.report(SHIFT, y, xi, w)

            yield from y[:SHIFT]
                
class RLS(AbstractWienerFilter):
    """Block recursive-least-squares (RLS) adaptive filter.

    Keyword parameters (passed to the constructor or to ``reset``):
        ALPHA: initialisation constant; the inverse correlation matrix starts
               as ``I / ALPHA``.
        PHI:   forgetting factor. ``P`` is divided by it after each update;
               1.0 means no forgetting.
        K:     number of taps.
        CHUNK: number of samples processed per update (the last block may be
               shorter).

    Update per block ``U`` (K x size, one tap-input vector per column)::

        G  = P U (PHI I + U^H P U)^{-1}
        xi = d - w^H U
        w <- w + G conj(xi)
        P <- (I - G U^H) P / PHI
    """

    def _reset(self, ALPHA=0.00001, PHI=1.0, K=512, CHUNK=128):
        self.K     = K
        self.ALPHA = ALPHA
        self.PHI   = PHI
        self.CHUNK = CHUNK

        self.c = np.zeros(self.K)           # tap-input vector of the last sample seen
        self.w = np.zeros(self.K)           # tap weights
        self.P = np.eye(self.K) / self.ALPHA  # inverse correlation matrix estimate

    def process_pair(self, itr):
        """Filter ``(inp, res)`` pairs in blocks of up to ``CHUNK`` samples.

        Args:
            itr: iterable of ``(inp, res)`` pairs. ``inp`` is the reference
                input sample and ``res`` the desired (observed) sample.

        Yields:
            One output per input sample, computed with the block's updated
            weights (a posteriori estimate).
        """
        K     = self.K
        PHI   = self.PHI
        eye_k = np.eye(K)
        for chunk in chunked(itr, self.CHUNK):
            size  = len(chunk)
            chunk = np.asarray(chunk)
            inp   = chunk[:, 0]
            res   = chunk[:, 1]

            # Tap-input vectors of this block as columns (newest sample first).
            c      = np.r_[inp[0], self.c[:-1]]
            U      = toeplitz(c, inp)
            d      = res
            self.c = np.r_[inp[1:][::-1], c][:K]

            last_P = self.P
            last_w = self.w

            # P @ U is needed twice; compute it once.
            PU = last_P @ U
            S  = PHI * np.eye(size) + np.conj(U).T @ PU
            # G = PU @ inv(S), obtained from the linear system S^T G^T = PU^T.
            G  = np.linalg.solve(S.T, PU.T).T
            # conj(w) keeps the error consistent with the G @ conj(xi) update
            # for complex data (no-op for real data).
            xi = d - np.conj(last_w) @ U
            w  = last_w + G @ np.conj(xi)
            P  = (eye_k - G @ np.conj(U).T) @ last_P / PHI
            y  = np.conj(w) @ U

            self.w = w
            self.P = P
            self.report(size, y, xi, w)

            yield from y

テストコード

グラフを作ったコード。

# -*- coding: utf-8 -*-

import wiener_filter
import numpy as np
import matplotlib.pyplot as plt
import scipy.signal as sg
from time import perf_counter

# H: impulse response of the synthetic echo path. It has 512 taps with
# values 1 / (1 + exp(1e-4 * (k - 128)^2)), largest at k = 128, and is
# then normalised so the taps sum to 1.
H = 1 / (1 + np.exp(0.0001 * (np.arange(0, 512) - 128) ** 2))
H = H / np.sum(H)

# t: discrete time axis, 4096 evenly spaced points on [0, 2].
t  = np.linspace(0, 2, 4096)

# x: played-back (reference) signal. It is sin(2*pi*t) plus a 0.2-amplitude
# component at ten times that frequency.
# NOTE(review): pi * 20 * 0.1 == 2 * pi, so the "+ 0.1" offset shifts the
# second component by a whole period and has no effect. The commented-out
# variant below applies a 0.1*pi phase instead.
x = np.sin(2 * np.pi * t) + 0.2 * np.sin(np.pi * 20 * (t + 0.1))
# x = np.sin(2 * np.pi * t) + 0.2*np.sin(np.pi * 20 * t + 0.1 * np.pi)

# z: echo signal, i.e. x filtered by the FIR echo path H.
z = sg.lfilter(H, [1], x)

# n: Gaussian noise with standard deviation 0.1. No seed is set here, so it
# differs from run to run.
n = np.random.randn(len(t)) * 0.1

# d: observed (microphone) signal = echo + noise. It is passed to the filters
# as the desired signal.
d = z + n

# Show the observed, reference and echo signals.
plt.plot(t, d, label="d")
plt.plot(t, x, label="x")
plt.plot(t, z, label="z")
plt.legend()
plt.show()

def measure(alg, x, d):
    """Run ``alg.process(x, d)`` to completion and time it.

    Args:
        alg: filter object exposing ``process(x, d)``.
        x: reference input samples.
        d: desired (observed) samples.

    Returns:
        ``(outputs, elapsed)``, where ``outputs`` is the list of everything
        ``process`` yielded and ``elapsed`` is the wall-clock time in seconds
        (``perf_counter``) spent producing it.
    """
    tic = perf_counter()
    outputs = list(alg.process(x, d))
    return outputs, perf_counter() - tic

# One SamplingLogReporter per filter. LMS/NLMS report once per sample, so
# they keep every 100th report (default interval). APA/RLS report once per
# block and keep every report (interval 1).
lms  = wiener_filter.LMS( reporter=wiener_filter.SamplingLogReporter())
nlms = wiener_filter.NLMS(reporter=wiener_filter.SamplingLogReporter())
apa  = wiener_filter.APA( reporter=wiener_filter.SamplingLogReporter(1))
rls  = wiener_filter.RLS( reporter=wiener_filter.SamplingLogReporter(1))

# Run each filter with its default parameters and time it. The timing
# includes the reporter overhead.
lms_y,  lms_t  = measure(lms,  x, d)
nlms_y, nlms_t = measure(nlms, x, d)
apa_y,  apa_t  = measure(apa,  x, d)
rls_y,  rls_t  = measure(rls,  x, d)

print("lms", lms_t)
print("nlms", nlms_t)
print("apa", apa_t)
print("rls", rls_t)

# Bar chart of processing times, each bar annotated with its value in seconds.
label = ["LMS", "NLMS", "APA", "RLS"]
colors = ['#8c564b', '#9467bd', '#e377c2', '#d62728']
rst   = [lms_t, nlms_t, apa_t, rls_t]
bar_list = plt.bar(label, rst)
for i, b in enumerate(bar_list):
    b.set_color(colors[i])

for i, val in enumerate(rst):
    plt.text(i, val, "%0.2fs" % val, horizontalalignment='center', verticalalignment='bottom' )
plt.show()

# 2x2 grid: observed signal d, each filter's output, and the echo z.
plt.figure(figsize=(10,5))
plt.subplot(221)
plt.plot(t, d,      label="d")
plt.plot(t, lms_y,  label="LMS", c=colors[0])
plt.plot(t, z,      label="z")
plt.legend()


plt.subplot(222)
plt.plot(t, d,      label="d")
plt.plot(t, nlms_y, label="NLMS", c=colors[1])
plt.plot(t, z,      label="z")
plt.legend()

plt.subplot(223)
plt.plot(t, d,      label="d")
# APA yields fewer samples than it receives (its output lags the input), so
# only the first len(apa_y) time points are plotted.
plt.plot(t[:len(apa_y)], apa_y,  label="APA", c=colors[2])
plt.plot(t, z,      label="z")
plt.legend()

plt.subplot(224)
plt.plot(t, d,      label="d")
plt.plot(t, rls_y,  label="RLS", c=colors[3])
plt.plot(t, z,      label="z")
plt.legend()
plt.show()


# 2x2 grid of estimation errors: each filter's output minus the echo z.
plt.figure(figsize=(10,5))
plt.subplot(221)
plt.plot(t, lms_y  - z, label="LMS - z", c=colors[0])
plt.legend()
plt.subplot(222)
plt.plot(t, nlms_y - z, label="NLMS - z", c=colors[1])
plt.legend()
plt.subplot(223)
# Invisible (white) zero line so the APA x-axis spans the full time range
# even though apa_y is shorter than t.
plt.plot(t, np.zeros(len(t)), color="w")
plt.plot(t[:len(apa_y)], apa_y  - z[:len(apa_y)], label="APA - z", c=colors[2])
plt.legend()
plt.subplot(224)
plt.plot(t, rls_y  - z, label="RLS - z", c=colors[3])
plt.legend()
plt.show()
3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?