はじめに
- 私は近接勾配法を若いころに本で読んだのですが、正直何言ってるのかわからなかった
- 最近、ある程度理解できたのでまとめます
解きたい問題
微分可能な凸関数$f(w)$と、微分不可能な凸関数$g(w)$で構成される目的関数に対し、次のような$\boldsymbol{w}^*$を解く問題を考えます:
\def\argmax{\mathop{\rm arg~max}\limits}
\def\argmin{\mathop{\rm arg~min}\limits}
\def\bw{\boldsymbol{w}}
\bw^* = \argmin_w\ (f(\bw) + g(\bw))
例: lassoの問題
\bw^* = \argmin_w\ \left(\frac{1}{2}\|\boldsymbol{y}-A\boldsymbol{w}\|_2^2 + \lambda\|\bw\|_1\right)
この場合$f, g$は、$f(\bw)=\frac{1}{2}\|\boldsymbol{y}-A\boldsymbol{w}\|_2^2,\ g(\bw) = \lambda\|\bw\|_1$で定義されています。
こういった問題を解くためのアルゴリズムが近接勾配法です。
解き方の考え方
上で述べた問題に対して最適化方法を考えます。
近接勾配法のキモとなる考え方は、目的関数の上界を作り、その上界の中で最小化することを繰り返すことです。1
この操作をすると、「目的関数が凸関数の場合、目的関数の最小値に到達すること」が一般的に証明されています。23
このやり方の問題は2点あります:
問題1. 「上界はどうやって作る?」
問題2. 「上界の最小値はどうやって計算する?」
今のところ$f, g$には凸関数という性質くらいしか仮定していないので、上界に関して議論しようがありません。また、上界を作っても、それが複雑な関数になったら上界の最小値すら計算できないかもしれません。
この問題を解決するために、飛び道具を出します。それが$L$-平滑性です。
L-平滑性
いろいろ喋ると面倒なので、簡単にまとめます。
この問題では$f$に対して $L$-平滑性を仮定します。これは「$f$の微分の変化率がそんなにやばくなくて、定数$L$で抑えられる」という仮定4で、問題1, 問題2を一気に解決します。
ちょっと省略しますが、一般に$L$-平滑な凸関数$f$は任意の点$\bw_k$で次のような上界を持ちます5:
f(\bw) \leq f(\bw_k) + \nabla f(\bw_k)^\mathrm{T}(\bw-\bw_k) + \frac{L}{2}\|\bw-\bw_k\|_2^2
$\bw_k$は定数であることに注意すると、この右辺の上界は$\bw$の2次関数であることがわかります。2次関数は解析的に扱いやすく、その最小点も簡単に求めることができます。
この仮定により、目的関数$f(\bw)+g(\bw)$の上界は作れて
f(\bw)+g(\bw) \leq f(\bw_k) + \nabla f(\bw_k)^\mathrm{T}(\bw-\bw_k) + \frac{L}{2}\|\bw-\bw_k\|_2^2 + g(\bw)
となります(問題1が解決)。さらに上界は2次関数と$g(\bw)$の和なので、それなりに最小点が求められそうです(問題2も解決?)。
近接勾配法の導出
…ということで$L$-平滑性の議論はかなり飛ばしましたが、$f$に$L$-平滑性を仮定しておけば、上で述べた「上界を作り、その上界の中で最小化することを繰り返す」方法の考え方が使えそうです。
計算していきましょう。
「上界を作り、最小化する」の数学的表現
$\bw_k$の更新1ステップを考えます。
「上界を作り、その上界の中で最小化する」という操作はまとめて数式で言うと単に
\bw_{k+1} = \argmin_\bw\ 上界_{\bw_k}(\bw)
です。よって今回の場合は
\bw_{k+1} = \argmin_\bw\ \left( f(\bw_k) + \nabla f(\bw_k)^\mathrm{T}(\bw-\bw_k) + \frac{L}{2}\|\bw-\bw_k\|_2^2 + g(\bw)\right)
となります。
更新1ステップ分の導出の詳細
具体的に計算を進めると
\begin{align}
\bw_{k+1} &= \argmin_{w}\ \left( f(\bw_k) + \nabla f(\bw_k)^\mathrm{T}(\bw-\bw_k) + \frac{L}{2}\|\bw-\bw_k\|_2^2 + g(\bw)\right)\\
&= \argmin_{w}\ \left( \nabla f(\bw_k)^\mathrm{T}(\bw-\bw_k) + \frac{L}{2}\|\bw-\bw_k\|_2^2 + g(\bw) \right) \qquad(定数を無視)\\
&= \argmin_{w}\ \left( g(\bw) + \frac{L}{2}\|\bw-\bw_k\|_2^2 + \nabla f(\bw_k)^\mathrm{T}(\bw-\bw_k) \right) \qquad(並び替え)\\
&= \argmin_{w}\ \left( g(\bw) + \frac{L}{2}\|\bw-\bw_k\|_2^2 + \nabla f(\bw_k)^\mathrm{T}(\bw-\bw_k) + \frac{L}{2}\left\|\frac{1}{L}\nabla f(\bw_k)\right\|_2^2 \right) \qquad(定数を加える)\\
&= \argmin_{w}\ \left( g(\bw) + \frac{L}{2}\left\|(\bw-\bw_k)+\frac{1}{L}\nabla f(\bw_k)\right\|_2^2 \right) \qquad(平方完成)\\
&= \argmin_{w}\ \left( g(\bw) + \frac{L}{2}\left\|\bw-\left(\bw_k-\frac{1}{L}\nabla f(\bw_k)\right)\right\|_2^2 \right) \qquad(整理)\\
&= \argmin_{w}\ \left( \frac{1}{L}g(\bw) + \frac{1}{2}\left\|\bw-\left(\bw_k-\frac{1}{L}\nabla f(\bw_k)\right)\right\|_2^2 \right) \qquad(定数倍)\\
&= \argmin_{w}\ \left( \eta g(\bw) + \frac{1}{2}\left\|\bw-\left(\bw_k-\eta\nabla f(\bw_k)\right)\right\|_2^2 \right) \qquad\left(\eta := \frac{1}{L} で置き換え\right)\\
&= \argmin_{w}\ \left( \eta g(\bw) + \frac{1}{2}\left\|\bw-\bw'\right\|_2^2 \right) \qquad\left(\bw' := \bw_k-\eta\nabla f(\bw_k) で置き換え\right)\\
\end{align}
となります。
導出結果
長くなりましたが、まとめると更新ステップは次のような計算になりそうです。
導出した更新ステップ:
① $\bw' = \bw_k-\eta\nabla f(\bw_k)$ を計算
② $\bw_{k+1} = \argmin_{w}\ \left( \eta g(\bw) + \frac{1}{2}\left\|\bw-\bw'\right\|_2^2 \right)$ で更新
色々計算しましたが、更新ステップの②はややこしいままです。
$g$の特性を盛り込まないとこれ以上は進めなさそうです。
近接写像の導入
ところで、更新ステップの①$\bw' = \bw_k-\eta\nabla f(\bw_k)$はよく見ると$f$の最急勾配降下法そのものです。
これを考えると、元の「上界を作り、その上界の中で最小化する」という考え方から、更新ステップ②は$g$を考慮しながら遠くに行きすぎないように補正をしていると考えられます。
こういう汚い計算はとりあえず名前を付けてあげると分かった気になります。
近いところに補正をしているということで「近接写像」とでも名付けましょう。近接写像$\mathrm{prox}_\phi(\bw')$を次のように定義します:
\mathrm{prox}_\phi(\bw') := \argmin_w\ \left(\phi(\bw) + \frac{1}{2}\left\|\bw-\bw'\right\|_2^2 \right)
これを用いてまとめた更新ステップが近接勾配法です。つまり、以下の更新ステップで計算します。
近接勾配法の更新ステップ:
① $\bw' = \bw_k-\eta\nabla f(\bw_k)$ を計算
② $\bw_{k+1} = \mathrm{prox}_{\eta g}(\bw')$ で更新
とってもすっきりしましたね!これにて近接勾配法の導出完了です。6
このように勾配法と近接写像を繰り返すため、近接勾配法と呼ばれています。7
おわりに
- お気持ちだけなら結構簡単でしたね
- L平滑性の議論飛ばしてすみません、、、上界の議論をしたいときに頻繁に出てくるので勉強しておくといいと思います。
-
これは上界最小化アルゴリズム(MMアルゴリズム)で用いられる考え方です。つまり近接勾配法はMMアルゴリズムを特定の問題(つまり「解きたい問題」の章で述べた問題)に適用することで導出できます。 ↩
-
例えば「金森敬文, 鈴木大慈, 竹内一郎, 佐藤一誠. 機械学習のための連続最適化(機械学習プロフェッショナルシリーズ), 講談社, 2016」のpp.223-224を参照のこと。
単調減少性を担保するには、点$w_t$で構成する上界$u_{w_t}(w)$は2つの条件が必要です: 1. 任意の$w$において、$f(w) \leq u_{w_t}(w)$; 2. $f(w_t) = u_{w_t}(w_t)$.
つまり、点$w_t$で上界を構成したとき、$w_t$では元の関数fと上界が、くっついてないといけません。これは上で載せたGIF画像でもそのように描画してあります。 ↩ -
今回の場合、目的関数$f(w)+g(w)$は凸関数です。$f(w), g(w)$がそれぞれ凸関数で、「凸関数の和は凸関数」という性質から導けます。 ↩
-
「関数$f$が$L$-平滑である」の具体的な定義は、任意の$\boldsymbol{x}, \boldsymbol{y}$に対して$\|\nabla f(\boldsymbol{x})-\nabla f(\boldsymbol{y})\| \leq L \|\boldsymbol{x}-\boldsymbol{y}\|$が成り立つことです。 ↩
-
簡単に証明できます。例えば「金森敬文, 鈴木大慈, 竹内一郎, 佐藤一誠. 機械学習のための連続最適化(機械学習プロフェッショナルシリーズ), 講談社, 2016」のpp.35を参照のこと。 ↩
-
近接写像はスパース推定系の問題だと個別に解析されていて、具体的な形が知られています。例えばlassoの問題なら解析的に導出できて「ソフト閾値関数」という関数が出てきます。こちらのSekinoさんの記事などでも紹介されています。 ↩
-
この記事では「近接写像」をお気持ちで導入しましたが、実は近接点法という一般的な最適化手法で近接写像は定義されています(説明が煩雑になるため省略しました)。近接勾配法では導出する過程で自動的に「勾配法」と「近接点法」が出てくるため、「近接勾配法」と呼ばれています。 ↩