$\def\bm{\boldsymbol}$
概要
- オンラインで開催している『モデルベース深層学習と深層展開』読み会で得られた知見や気づきをメモしていく
- ついでに、中身の理解がてらJuliaサンプルコードをPythonに書き直したコードを晒していく
- 自動微分ライブラリにはJAXを使用する
第7回
大まかな内容
- 二値制約二次計画問題の解法に深層展開を適用した例について
- DU-BQP法
- スパース信号再構成の問題の解法の近接勾配法に深層展開を適用した例について
- DU-ISTA法
議論になったこと
4.2節冒頭 あたり
- 組み合わせ最適化=離散最適化
- ”NP困難”とかは、効率的な解法があるかないかを規準としたような問題のクラス分けの一種
4.2.1節 あたり
- $C\equiv \lbrace-1,1\rbrace^n$に直接射影するのがうまくいかないのは何故?
- sign関数が微分不可だから勾配法と組み合わせるとうまくいかなくなるとか?
4.3.2節 あたり
- L-平滑性ってなんだっけ?
- ->P72で出てきた凸関数の性質。連続的に微分可能で、勾配がリプシツ定数Lでリプシツ連続な事
- 近接写像とは?
- ->式(4.58)にあるような、引数のベクトルに近いかつ、ある関数の値を小さくするベクトルへの写像
- 式(4.58)の解ってどう出すの?
- ->計算できるかどうかは$h()$がどんな関数かに依存している
- 参考
- 今回は、Lasso正則化関数に対して式(4.58)を解くとソフトしきい値関数が出てくるみたいなかんじっぽい
- ->計算できるかどうかは$h()$がどんな関数かに依存している
- P72の射影勾配法はこの近接勾配法の特殊形?
個人的な気づきなど
- 式(4.69)と式(4.67)を見比べると、LISTAはオリジナルISTAの観測方程式由来の$\bm{A}$をデータから推定するモデルフリーの手法になっている
以下8.31追記
- 近接写像については、参加メンバに教えてもらった動画が分かりやすかった
- 改めて教科書の文脈を整理すると
- 元信号が疎ベクトルという仮定を利用して、再構成問題をLasso最適化問題に落とし込む
- →しかし、L1ノルムの項は微分不可なので、そのままでは勾配法ベースの最適化ができない
- →これに対する一つの対処法が近接勾配法(教科書P100)
- 目的関数を微分可能項と不可項に分けて、微分可能項に関する勾配降下後に微分不可項に関する近接写像を使って勾配降下”っぽい”ことをするという方法
- →近接写像はその定義内に最適化問題があるので、いつも自明ではないが、L1には解が求まっていて、それがソフトしきい値関数
- 直感的には、ゼロ付近の値をゼロに、それ以外には少しゼロ側に寄るバイアスを与えるような式。これは元の最適化問題におけるL1ノルム項の役割(疎ベクトルへの正則化)に相当する事がわかる。
- ※少しややこしいのは深層展開を使うときの教科書の例では、近接写像内のパラメタの学習のために近接写像を含む計算のバックプロパゲーションを行うが、これは元の最適化計算で問題になった微分可能性の話とはまた違う話(なので、ごっちゃにしないように注意)
- 改めて教科書の文脈を整理すると
プログラムでの理解
問題設定
-
非ゼロ要素数が$n$に対し十分に小さいスパース信号$\bm{x}\in \mathbb{R}^n$がある
-
以下の線形観測ベクトル$\bm{y}\in\mathbb{R}^m$と観測行列$\bm{A}\in\mathbb{R}^{m\times n}$があり$m<n$である
- $\bm{y}=\bm{A}\bm{x}+\bm{w}$
- $\bm{w}\in\mathbb{R}^m$は雑音ベクトル
- $\bm{y}=\bm{A}\bm{x}+\bm{w}$
-
このとき、$\bm{A}$と$\bm{y}$を用いて$\bm{x}$を復元する
- $m<n$より$\bm{x}$のスパース性を考慮しなければ劣決定性問題になっている(連立方程式の数が変数より少ない)
-
最適化問題としては以下のLasso問題に帰着できる
- $\arg_{\bm{x}} \min \left(\frac{1}{2}||\bm{y}-\bm{Ax}||^2_2+\lambda||\bm{x}||_1 \right)$
- $\lambda$は正則化係数
- $\arg_{\bm{x}} \min \left(\frac{1}{2}||\bm{y}-\bm{Ax}||^2_2+\lambda||\bm{x}||_1 \right)$
-
ISTAはこのLasso問題を解くための反復更新アルゴリズム
- 適当な初期値から始めて、以下のように反復更新を繰り返す
$\bm{z}^{(t)}=\bm{x}^{(t)}+\gamma \nabla f$
$\bm{x}^{(t+1)}=S_\lambda(\bm{z}^{(t)})$ - ここで$f=||\bm{y}-\bm{Ax}||^2_2$で、$\nabla f=\bm{A}^{\top}(\bm{Ax}^{(t)}-\bm{y})$
- 適当な初期値から始めて、以下のように反復更新を繰り返す
-
深層展開を組み合わせたDU-ISTAは$\gamma$と$\lambda$を学習によって獲得しISTAの収束速度を加速する
-
問題としては以前に扱ったものとかなり似ている
Pythonで実装
- やってることの多くは、これまでと同じなので、説明は端折りがち
必要ライブラリインポート
import numpy as np
import numpy.linalg as LA
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers
from tqdm.notebook import trange
from functools import partial
問題設定
n = 128
m = 64
sigma = 0.1
p = 0.05
num_itr = 40
A = np.random.randn(m, n)
eig, _ = LA.eig(A.T @ A)
gamma_opt = float(1./max(eig).real)
A = jnp.array(A)
- 推定対象のベクトルはベルヌーイ-ガウス分布で生成される(ある要素が比ゼロになるか否かがベルヌーイ分布で決定され、比ゼロの場合は値がガウス分布で決まる)が、ベルヌーイ分布の期待値パラメータを$p=0.05$として、疎なベクトルを生成する
学習データの生成関数の定義
def mini_batch(K):
seq = np.random.randn(n, K)
support = np.random.binomial(1,p,size=(n, K))
x = jnp.array(seq * support)
y = A @ x + sigma * jnp.array(np.random.randn(m, K))
return y, x
- ベルヌーイ-ガウス分布で疎ベクトルのバッチを生成
DU-ISTAの反復関数の定義
def softshrink(x, lam):
return jnp.sign(x) * jnp.max(jnp.append(jnp.abs(x) - lam, 0))
vec_softshrink = jax.jit(jax.vmap(softshrink, in_axes=[-1, None], out_axes=-1))
@partial(jax.jit, static_argnums=0)
def DU_ISTA(max_itr, y, eta, mu):
x = jnp.zeros((n, K))
for i in range(max_itr):
x -= eta[i] * A.T@(A@x -y)
x = vec_softshrink(x.reshape(1, -1), mu[i])
x = x.reshape((n, K)).real
return x
- 射影関数の書き方がよくないのか、ちょっと冗長なコードになってる?
誤差関数の定義
@jax.jit
def get_dot(x):
return x @ x.T
batch_get_dot = jax.vmap(get_dot, in_axes=-1, out_axes=-1)
@partial(jax.jit, static_argnums=1)
def loss(x_org, max_itr, y, eta, mu):
x_hat = DU_ISTA(max_itr, y, eta, mu)
return jnp.sum(batch_get_dot(x_org - x_hat))/K
- 前回も書いたが、別に
jax.numpy.linalg.norm()
等を使っても良い
学習
opt_init1, opt_update1, get_params1 = optimizers.adam(adam_lr)
opt_init2, opt_update2, get_params2 = optimizers.adam(adam_lr)
@partial(jax.jit, static_argnums=1)
def step(x_org, max_itr, y, step_num, opt_state1, opt_state2):
tmp_eta = get_params1(opt_state1)
tmp_mu = get_params2(opt_state2)
value, grads = jax.value_and_grad(loss, argnums=-2)(x_org, max_itr, y, tmp_eta, tmp_mu)
new_opt_state1 = opt_update1(step_num, grads, opt_state1)
value, grads = jax.value_and_grad(loss, argnums=-1)(x_org, max_itr, y, tmp_eta, tmp_mu)
new_opt_state2 = opt_update2(step_num, grads, opt_state2)
return value, new_opt_state1, new_opt_state2
def train(eta, mu):
opt_state1 = opt_init1(eta)
opt_state2 = opt_init2(mu)
for itr in trange(num_itr, leave=False):
for i in range(max_inner):
y, x_org = mini_batch(K)
value, opt_state1, opt_state2 = step(x_org, itr+1, y, i, opt_state1, opt_state2)
print("\r"+"\rloss:{}".format(value), end=" ")
return get_params1(opt_state1), get_params2(opt_state2)
K = 200
adam_lr = 5e-5
max_inner = 20
eta_init = gamma_opt*jnp.ones(num_itr)
mu_init = gamma_opt*jnp.ones(num_itr)
eta_trained, mu_trained = train(eta_init, mu_init)
- パラメタを明示的に分けて学習させた
評価
def comp_mse_DUGD(max_itr, eta, mu):
y, x_org = mini_batch(K)
return float(loss(x_org, max_itr, y, eta, mu))
DUGD_mse = [comp_mse_DUGD(i+1, eta_trained, mu_trained) for i in range(num_itr)]
結果
MSEの比較
信号復元の比較
- DU版のほうが、少ない反復回数で良い復元が得られる事がわかる
その他
- 特になし
バックナンバー
- 『モデルベース深層学習と深層展開』読み会 レポート(開催前準備編)
- 『モデルベース深層学習と深層展開』読み会レポート#0
- 『モデルベース深層学習と深層展開』読み会レポート#1
- 『モデルベース深層学習と深層展開』読み会レポート#2
- 『モデルベース深層学習と深層展開』読み会レポート#3
- 『モデルベース深層学習と深層展開』読み会レポート#4
- 『モデルベース深層学習と深層展開』読み会レポート#5
- 『モデルベース深層学習と深層展開』読み会レポート#6