作成しているコードのエラー文がわからない
Q&A
Closed
pythonの基本的な問題だと思うのですが、圧縮センシングにおける再構成手法の一つAMPを作成している所、以下のようなエラー文が出てしまい自分で解決ができない状態です。なにか分かる方がいれば教えていただきたいです。
発生している問題・エラー
該当するソースコード
class ISTA:
def __init__(self, A, x, sigma):
self.A=A
self.M, self.N=A.shape
self.x=x
self.sigma=sigma
Ax=np.dot(A, x)
n=np.random.normal(0, sigma**2, (self.M, 1))
self.y=Ax+n
self.s=np.zeros((self.N, 1))
self.mse=np.array([None])
self.AT=np.transpose(A)
def estimate(self, T=20, tau=0.5):
L=self.__set__lipchitz()
ganma=1/(tau*L)
for _ in range(T):
r=self._update_r()
w=self._update_w(ganma, r)
self.s=self._update_s(w, 1/L)
self.mse=self._add_mse()
def __set__lipchitz(self):
L=np.linalg.norm(self.AT.dot(self.A),ord=2)/(0.5)
return L
def _update_r(self):
return self.y - np.dot(self.A, self.s)
def _update_w(self, ganma, r):
return self.s + ganma * np.dot(self.AT, r)
def _update_s(self, w, thre):
return soft_threshold(w, thre)
def _add_mse(self):
mse = np.linalg.norm(self.s - self.x)**2 / self.N
self.mse=np.append(self.mse, mse)
return self.mse
def result_ISTA(self):
plt.plot(self.x.real)
plt.plot(self.s.real,)
plt.grid()
def result_ISTA_MSE(self):
MSE=[]
MSE=np.append(MSE, self.mse)
Figure, ax = plt.subplots()
plt.grid()
plt.yscale('log')
plt.plot(MSE)
ax.set_ylim([0.0001,10])
plt.xticks(np.arange(1, T+1, step=1))
plt.xlabel("iteration")
plt.ylabel("MSE")
class AMP(ISTA):
def __intit__(self, A, x, sigma):
super().__init__(A, x, sigma)
self.tau=np.array([None])
def estimate(self, T=20):
Onsager=np.zeros((self.M, 1))
for _ in range(T):
r=self._update_r()
w=self._update_w(r+Onsager)
tau=self._update_tau(r+Onsager)
self.s=self._update_s(w, tau)
Onsager=np.sum(self.s != 0)/self.M*(r+Onsager)
self._add_mse
def _update_w(self, r):
return self.s+np.dot(self.AT, r)
def _update_tau(self, r):
return self.tau.append((np.linalg.norm(r)**2)/self.M)
def _update_s(self, w, tau):
return soft_threshold(w, tau**0.5)
def result_AMP(self):
plt.plot(self.x.real)
plt.plot(self.s.real)
plt.grid()
def result_AMP_MSE(self):
MSE=[]
MSE=np.append(MSE, self.mse)
print(MSE)
Figure, ax = plt.subplots()#グラフオブジェクトを生成
plt.grid()
plt.yscale('log')
plt.plot(MSE)
ax.set_ylim([0.0001,10])
plt.xticks(np.arange(1, T+1, step=1))
plt.xlabel("iteration")
plt.ylabel("MSE")
自分で試したこと
エラー文の意味は、tauが存在していないと言われていると解釈したんですが、自分ではなぜこのようなエラー文が出たのか分からない
0