$\def\bm{\boldsymbol}$
概要
- オンラインで開催している『モデルベース深層学習と深層展開』読み会で得られた知見や気づきをメモしていく
- ついでに、中身の理解がてらJuliaサンプルコードをPythonに書き直したコードを晒していく
- 自動微分ライブラリにはJAXを使用する
第5回
大まかな内容
- 以下の最適化のトピックを、深層展開の視点を絡めて取り扱った
- 射影勾配法
- 勾配雑音を含む勾配法
- 非凸最適化問題への勾配法の適用
議論になったこと
3.4.1節 あたり
- リプシッツ性関連
- そもそも数学的に滑らかってどんな定義だっけ?
- ->右からの導関数も左からの導関数も同じ点に収束する
- それよりはリプシッツ性はかなり弱い性質、リプシッツ定数Lの大きさによっては、人間の目にはカクッとして見える場合もある
- そもそも数学的に滑らかってどんな定義だっけ?
- 劣一次収束、一次収束関連
- 1次法で反復更新をするときに、ステップ数tが増える毎に最適値と現在値の誤差がどんなオーダーで減少していくかの話
- たとえば劣一次収束は、定数かける$\frac{1}{t}$ずつ誤差が減っていく、1次収束は定数かける$\exp(-t)$ずつ減っていく。これは結構強い性質。
3.4.3節あたり
- tanh関数ではなくその近似関数を使っているのは何か理由がある?
- ->答え出ず
- ボックス制約の境界に変数が張り付いた場合、その後の反復の挙動はどうなる?
- 目的関数が凸なら、その軸の変数はその境界に張り付いたままになりそう
- 目的関数が非凸なら、その後の最適化の方法によって、また境界以外に移動する事もありそう
- DU-PGD法のサンプルコードを眺めながら
- 今回周期条件は入ってないが、パラメタ数がかなり増える事になる?
- ->勾配降下ステップと同じだけのパラメタ数になる。まあでも、一般的なDeepの学習例に比べたらかわいいもん?
- 今回周期条件は入ってないが、パラメタ数がかなり増える事になる?
- 多様体最適化っぽい事をしているように見えるが何か技術的な関連はある?
- 式(3.46)がレトラクションっぽい
- ぽいけど、それ以上はわからん
3.5.1節 あたり
- 式(3.60)って具体的にどんなもの?
- $i$はデータ集合で分けられる?
- $i$は$f_i$の関数形によって分けられる?
3.5.3節 あたり
- なんでノイズ入れたら良くなるんだっけ?
- ->教科書では、通常のDU-GD法とDU-NoisyGD法の比較を陽にやっているわけではない。でもパッと見た感じ性能は落ちてる?
- 教科書では、適当な問題設定で挙動を観察しているに過ぎないが、本当はもっとDU-NoisyGDを使うべき問題設定がある説
- ->教科書では、通常のDU-GD法とDU-NoisyGD法の比較を陽にやっているわけではない。でもパッと見た感じ性能は落ちてる?
- P79最後に”ステップサイズを減少させるスケジュールの設計”とあるが、何かユーザがその気持ちを設計している?
- ->あくまで、学習によって結果的に$\eta$の列が得られた事をスケジュールと表現しているだけで、ユーザの意図的なものは反映されていないと思う
3.6.2節 あたり
- 非凸最適化問題に対する王道的な解法ってどんなの?
- シミュレーテッドアニーリング法などのヒューリスティクス手法
- パウエル法
- BFGS法
- これらは初期値依存なことが多く、その際初期値だけ貪欲法などの別手法で解くこともある
- 自分が局所解に落ちているか否かを判別出来るかどうかで、やれる事が変わってきそう
- 教科書の手法は式(3.74)をみると、ラストリギンの大域最小点を使っているので、間接的にこれが判別できる問題設定になっていると言っていい?
- 大域最小点が分からないときは、上に挙げた方法などで得た準最適解を使うなどの方法になる?
- 教科書の手法は式(3.74)をみると、ラストリギンの大域最小点を使っているので、間接的にこれが判別できる問題設定になっていると言っていい?
個人的な気づきなど
- NoisyGD法の説明で出てくる最適化問題はGD法のそれと違うものだと思いこんでいたが、改めて考えてみると別に問題設定自体は同じようだとわかった。
プログラムでの理解
問題設定
- 以下のようなボックス制約付き最小化問題を考える
- $\min_{\bm{x}\in\mathbb{R}^n} \frac{1}{2}||\bm{y}-\bm{A}\bm{x}||^2$ subject to $-1\le\bm{x}\le1$
- $\bm{A}$は定数行列で事前に与えられている
- $\bm{y}$は観測値で$\mathcal{N}(\bm{0},\bm{I})$に従って生成されている
- $\min_{\bm{x}\in\mathbb{R}^n} \frac{1}{2}||\bm{y}-\bm{A}\bm{x}||^2$ subject to $-1\le\bm{x}\le1$
- これに対し以下のような繰り返し計算を収束させ解を得る
$$
\bm{z}^{(t)}=\bm{x}^{(t)}-\eta^{(t)}\bm{A}^{\top}(\bm{A}\bm{x}^{(t)}-\bm{y})
$$
$$
\bm{x}^{(t+1)}=\text{hardtanh}(\bm{z}^{(t)})
$$
-
これはボックス制約を満たすためにhardtanh関数を射影関数に選んだ射影勾配法である
-
深層展開版であるDU-PGD法では$\eta$を学習パラメータとして、事前にサンプリングした$\bm{y}$と、それに対応するQPソルバーで算出した解を使って、学習を行い、学習後に新たに与えられる$\bm{y}$に対しても高速に勾配を下れるようにする
Pythonで実装
必要ライブラリインポート
import numpy as np
import jax
import jax.numpy as jnp
from jaxopt import BoxOSQP
from jax.example_libraries import optimizers
from tqdm.notebook import trange
from functools import partial
問題設定
n = 10
m = 20
K = 200
A = np.random.randn(m,n)
A = jnp.array(A)
- $\bm{A}$は最初に標準ガウス分布から生成して固定する
学習データになる最適解の算出関数の定義
def get_opt(Q, c, A, l, u):
qp = BoxOSQP()
sol = qp.run(params_obj=(Q, c), params_eq=A, params_ineq=(l, u)).params
return sol.primal[0] #双対な解のうちのひとつ
@jax.jit
def gen_x_opt(y):
x_opt = get_opt(A.T@A, -y@A, jnp.eye(n), -jnp.ones(n), jnp.ones(n))
return x_opt
batch_gen_x_opt = jax.vmap(gen_x_opt, in_axes=[-1], out_axes=-1)
-
ここでは、制約付き二次計画問題を解くための汎用ソルバーとして、Jaxに内蔵されたOSQPを利用する
- ※OSQPはJax関係なく開発されているオープンソースのQPソルバー
- これ単体でもかなり有用なツールなので、興味ある人は弄ってみて欲しい
- 二次計画問題を以下の形式に一般化したときの行列とベクトル($\bm{Q},\bm{c},\bm{A},\bm{l},\bm{u}$)を引数として与えることで簡単に解を得ることができる
$
\min \frac{1}{2}\bm{x}^{\top}\bm{Q}\bm{x}+\bm{c}^{\top}\bm{x}
$
$
st. \bm{Ax} = \bm{z}, \bm{l} \leq \bm{z} \leq \bm{u}
$
- ※OSQPはJax関係なく開発されているオープンソースのQPソルバー
-
いつものように、
jax.vmap
を使って複数個の入力に対してバッチを返すようにしている
射影勾配法の反復処理の定義
def DU_ProjectedGD(max_itr, y, eta):
x = jnp.zeros((n, K))
for i in range(max_itr):
x -= eta[i] * A.T @ (A@x - y)
x = jax.nn.hard_tanh(x)
return x
- hardtanh関数はJaxに実装されている
jax.nn.hard_tanh()
を利用した
誤差関数と学習関数の定義
def get_dot(x):
return x @ x.T
batch_get_dot = jax.vmap(get_dot, in_axes=-1, out_axes=-1)
def loss(x_opt, max_itr, y, eta):
x_hat = DU_ProjectedGD(max_itr, y, eta)
return np.sum(batch_get_dot(x_opt - x_hat))/K
adam_lr = 3e-3
num_itr = 25
max_inner = 40
opt_init, opt_update, get_params = optimizers.adam(adam_lr)
@partial(jax.jit, static_argnums=1)
def step(x_opt, max_itr, y, step_num, opt_state):
value, grads = jax.value_and_grad(loss, argnums=-1)(x_opt, max_itr, y, get_params(opt_state))
new_opt_state = opt_update(step_num, grads, opt_state)
return value, new_opt_state
def train(eta):
opt_state = opt_init(eta)
for itr in trange(num_itr, leave=False):
for i in range(max_inner):
y = np.random.randn(m, K)
x_opt = jnp.array(batch_gen_x_opt(y))
value, opt_state = step(x_opt, itr+1, y, i, opt_state)
print("\r"+"\rloss:{}".format(value), end=" ")
return get_params(opt_state)
- この辺は、前々回あたりから大体同じ実装
- ざっくりいうと、インクリメンタル学習で少しずつ勾配降下回数を増やしながら、更にインナーループとして、同じ降下回数で何回か学習を回す
- 今回は周期条件は課さない
- ざっくりいうと、インクリメンタル学習で少しずつ勾配降下回数を増やしながら、更にインナーループとして、同じ降下回数で何回か学習を回す
学習と評価
eta_init = jnp.zeros(num_itr)
eta_trained = train(eta_init)
def comp_mse_DUPGD(max_itr, eta):
y = np.random.randn(m, K)
x_opt = jnp.array(batch_gen_x_opt(y))
return loss(x_opt, max_itr, y, eta)
DUPGD_mse = [comp_mse_DUPGD(i+1, eta_trained) for i in range(num_itr)]
- 学習後に新しい問題インスタンスとして学習時とは違う$\bm{y}$を生成し勾配降下の様子をみる
結果
その他
- 特になし
バックナンバー
- 『モデルベース深層学習と深層展開』読み会 レポート(開催前準備編)
- 『モデルベース深層学習と深層展開』読み会レポート#0
- 『モデルベース深層学習と深層展開』読み会レポート#1
- 『モデルベース深層学習と深層展開』読み会レポート#2
- 『モデルベース深層学習と深層展開』読み会レポート#3
- 『モデルベース深層学習と深層展開』読み会レポート#4