アルゴリズム
機械学習
数学
最適化

近接勾配法(Proximal Gradient Method)

はじめに

効くか効かないかわからない特徴量が大量にあって、中にはいくつか効くものがきっとある・・・というときに、L1正則化やGroup LASSOが用いられます。これらは微分不可能な点を含むため、通常の勾配法では解けません。
そこで、微分不可能な点を含む凸関数最適化の一手法である近接勾配法について、勉強したことをまとめてみました。

近接勾配法の更新式

微分可能な凸関数$f(\boldsymbol{x})$と微分不可能な点を含む凸関数$g(\boldsymbol{x})$に対して、$F(\boldsymbol{x})=f(\boldsymbol{x})+g(\boldsymbol{x})$を最小化します。
ステップ$k$における点$\boldsymbol{x}_k$とステップ幅$\eta$に対して、近接勾配法は

\begin{align}
\boldsymbol{x}_{k+1}
 &= {\rm prox}_{\eta g}( \boldsymbol{x}_k - \eta \bigtriangledown f(\boldsymbol{x}_k) ) \\
{\rm prox}_g(\boldsymbol{y})
 &\equiv {\rm argmin}_\boldsymbol{x} \left\{ g(\boldsymbol{x}) + \frac{1}{2} \parallel \boldsymbol{x} - \boldsymbol{y} \parallel^2 \right\}
\end{align}

と更新します。パッと見仰々しいのですが、式を読み込んでいきます。

Proximal Operator

${\rm prox}_g$は$g$のproximal operator(近接写像)と呼ばれる写像で、入力$\boldsymbol{y}$との二乗距離による罰則項$\frac{1}{2} \parallel \boldsymbol{x} - \boldsymbol{y} \parallel^2$と$g(\boldsymbol{x})$の和を最小化する点に写像されます。微分可能な$f$の勾配に従って通常の勾配法のように更新しておいて、微分不可能な$g$についてproximal operatorで考慮する、という処理になっています。制約付き最適化で、制約を無視して勾配方向に更新してから、実行可能領域に射影する勾配射影法(Gradient Projection Method)と同様の手続きになっていて、proximal operatorは射影の一般化になっています。

近接勾配法の更新式を展開してみる

近接勾配法の更新式(再掲):

\begin{align}
\boldsymbol{x}_{k+1}
 &= {\rm prox}_{\eta g}( \boldsymbol{x}_k - \eta \bigtriangledown f(\boldsymbol{x}_k) ) \\
{\rm prox}_g(\boldsymbol{y})
 &\equiv {\rm argmin}_\boldsymbol{x} \left\{ g(\boldsymbol{x}) + \frac{1}{2} \parallel \boldsymbol{x} - \boldsymbol{y} \parallel^2 \right\}
\end{align}

で${\rm prox}$の下付き添え字が、上の更新式では$\eta g$、下のproximal operatorの定義の式では$g$になっていることにご注意ください。$\eta$を忘れるとうまくいきません!
$\eta$付きの${\rm prox}_{\eta g}( \boldsymbol{y} ) $は

\begin{align}
{\rm prox}_{\eta g}(\boldsymbol{y})
 &= {\rm argmin}_\boldsymbol{x} \left\{\eta g(\boldsymbol{x}) + \frac{1}{2} \parallel \boldsymbol{x} - \boldsymbol{y} \parallel^2 \right\}
\end{align}

のことです。勾配法のステップ幅と同じ$\eta$で罰則項の効き目を調整していますね。
展開してみると、

\begin{align}
\boldsymbol{x}_{k+1}
 &= {\rm prox}_{\eta g}( \boldsymbol{x}_k - \eta \bigtriangledown f(\boldsymbol{x}_k) ) \\
 &= {\rm argmin}_\boldsymbol{x} \left(
     \eta g(\boldsymbol{x}) + \frac{1}{2} \parallel \boldsymbol{x} - (\boldsymbol{x}_k - \eta \bigtriangledown f(\boldsymbol{x}_k)) \parallel^2
    \right) \\
 &= {\rm argmin}_\boldsymbol{x} \left(
     g(\boldsymbol{x}) + f(\boldsymbol{x}_k) + \frac{1}{2 \eta} \parallel \boldsymbol{x} - \boldsymbol{x}_k + \eta \bigtriangledown f(\boldsymbol{x}_k)) \parallel^2
    \right) \\
 &= {\rm argmin}_\boldsymbol{x} \left(
     g(\boldsymbol{x}) + f(\boldsymbol{x}_k) + \bigtriangledown f(\boldsymbol{x}_k)^T (\boldsymbol{x} - \boldsymbol{x}_k) + \frac{1}{2 \eta} \parallel \boldsymbol{x} - \boldsymbol{x}_k \parallel^2
    \right) \\
 &= {\rm argmin}_\boldsymbol{x} \left( g(\boldsymbol{x}) + \hat{f}_{\eta}(\boldsymbol{x};\boldsymbol{x}_k) \right) \\
\hat{f}_\eta(\boldsymbol{x};\boldsymbol{y})
 &\equiv f(\boldsymbol{y}) + \bigtriangledown f(\boldsymbol{y})^T (\boldsymbol{x} - \boldsymbol{y}) + \frac{1}{2\eta} \parallel \boldsymbol{x} - \boldsymbol{y} \parallel^2
\end{align}

となります(もっと詳しく)。これより近接勾配法の更新は、$g(\boldsymbol{x})$と、$f(\boldsymbol{x})$の$\boldsymbol{x}_k$近傍での二次近似$\hat{f}_\eta(\boldsymbol{x};\boldsymbol{x}_k)$の和を最小化していることがわかります。
$g(\boldsymbol{x})$の項を除くと、

\begin{align}
\boldsymbol{x}_{k+1}
 &= {\rm argmin}_\boldsymbol{x} \hat{f}_\eta(\boldsymbol{x};\boldsymbol{x}_k) \\
 &= \boldsymbol{x}_k - \eta \bigtriangledown f(\boldsymbol{x}_k)
\end{align}

となり、これは$f(\boldsymbol{x})$に対する通常の勾配法です。

ステップ幅の決め方

関数$f$が、ある$L > 0$が存在して、任意の$\boldsymbol{x}$、$\boldsymbol{y}$に対して

|f(\boldsymbol{x})-f(\boldsymbol{y})| \leq L \parallel \boldsymbol{x} - \boldsymbol{y} \parallel

が成り立つことを$f$がリプシッツ連続であるといい、定数$L$をリプシッツ定数と言います。また、勾配$\bigtriangledown f$が、リプシッツ定数$\gamma$のリプシッツ連続であることを$\gamma$-平滑と言います。

$f$が$\gamma$-平滑な凸関数であれば

f(\boldsymbol{x})
 \leq \hat{f}_{1/\gamma}(\boldsymbol{x};\boldsymbol{y})
 = f(\boldsymbol{y}) + \bigtriangledown f(\boldsymbol{y})^T (\boldsymbol{x} - \boldsymbol{y}) + \frac{\gamma}{2} \parallel \boldsymbol{x} - \boldsymbol{y} \parallel^2

となります([1] p.35)。すなわち、$\gamma$-平滑な凸関数は、$\hat{f}_{1/\gamma}(\boldsymbol{x};\boldsymbol{y})$によって目的関数を抑えることができます。近接勾配法のステップ幅としては$\eta=1/\gamma$とすれば、$f$の上界を最小化していくことができます。実際には$\gamma$を直接求めることは難しいため、各ステップにおいて$f(\boldsymbol{x}_{k+1}) \leq \hat{f}_\eta(\boldsymbol{x}_{k+1};\boldsymbol{x}_k)$をチェックして、これが条件を満たすまで$\eta \leftarrow \beta \eta$($0 < \beta < 1$)として$\boldsymbol{x}_{k+1}$を求め直す、バックトラッキング法が用いられます。

近接勾配法のアルゴリズム

ステップ幅を決めるバックトラッキング法を組み込むと、アルゴリズムは次のようになります。

\begin{align}
&\eta = \eta_0 \\
&{\rm for} \; k=1 \; {\rm to} \; K \; {\rm do} \\
&\quad {\rm while \; not \; broken \; do} \\
&\qquad \boldsymbol{x}_k = {\rm prox}_{\eta g}( \boldsymbol{x}_{k-1} - \eta \bigtriangledown f(\boldsymbol{x}_{k-1}) ) \\
&\qquad {\rm if} \; f(\boldsymbol{x}_k) \leq \hat{f}_\eta(\boldsymbol{x}_k;\boldsymbol{x}_{k-1}) \; {\rm then}\\
&\qquad \quad {\rm break} \\
&\qquad {\rm else} \\
&\qquad \quad \eta \leftarrow \beta \eta \\
&\qquad {\rm end \; if} \\
&\quad {\rm end \; while} \\
&{\rm end \; for}
\end{align}

Proximal Operatorの解析解

L2正則化のProximal Operator

L2正則化の場合、近接勾配法は必要ありませんが練習として解いてみます。

\begin{align}
{\rm prox}_{\frac{\lambda}{2} \parallel \cdot \parallel^2}(\boldsymbol{x})
 &= {\rm argmin}_\boldsymbol{y} \left\{ \frac{\lambda}{2} \parallel \boldsymbol{y} \parallel^2 + \frac{1}{2} \parallel \boldsymbol{x} - \boldsymbol{y} \parallel^2 \right\}
\end{align}

カッコ内を$\boldsymbol{y}$で偏微分すると$(1+\lambda)\boldsymbol{y} - \boldsymbol{x}$なので、$\boldsymbol{y}$について解いて

\begin{align}
{\rm prox}_{\frac{\lambda}{2} \parallel \cdot \parallel^2}(\boldsymbol{x})
 = \frac{1}{1+\lambda} \boldsymbol{x}
\end{align}

を得ます。

L1正則化のProximal Operator

次はL1正則化の場合です。

\begin{align}
{\rm prox}_{\lambda \parallel \cdot \parallel_1}(\boldsymbol{x})
 &= {\rm argmin}_\boldsymbol{y} \left\{
      \lambda \parallel \boldsymbol{y} \parallel_1
    + \frac{1}{2} \parallel \boldsymbol{x} - \boldsymbol{y} \parallel^2
    \right\}
\end{align}

カッコ内を$y_i$で偏微分すると

\begin{align}
 \left\{
  \begin{array}{ll}
    \lambda + y_i - x_i & y_i \geq 0 \\
   -\lambda + y_i - x_i & y_i < 0
  \end{array}
 \right. \\
\end{align}

なので、$y_i$について解いて

\begin{align}
\Big( {\rm prox}_{\lambda \parallel \cdot \parallel_1}(\boldsymbol{x}) \Big)_i
 &= \left\{
  \begin{array}{ll}
   x_i - \lambda & x_i > \lambda \\
   0 & |x_i| \leq \lambda \\
   x_i + \lambda & x_i < -\lambda
  \end{array}
 \right. \\
 &= {\rm sign}(x_i) (|x_i| - \lambda)^+
\end{align}

を得ます。

Group LASSOのProximal Operator

重複のないGroup LASSOの場合は、(導出を追えていないので)結果だけ載せておきます。

\begin{align}
{\rm prox}_{\lambda \parallel \cdot \parallel}(\boldsymbol{x})
 &= \left\{
  \begin{array}{ll}
   \boldsymbol{x} - \lambda \frac{\boldsymbol{x}}{\parallel \boldsymbol{x} \parallel} & \parallel \boldsymbol{x} \parallel > \lambda \\
   \boldsymbol{0} & \parallel \boldsymbol{x} \parallel \leq \lambda
  \end{array}
 \right. \\
 &= \frac{\boldsymbol{x}}{\parallel \boldsymbol{x} \parallel} (\parallel \boldsymbol{x} \parallel - \lambda)^+
\end{align}

近接勾配法の更新式の展開(補足)

近接勾配法の更新式の展開ですが、もう少し詳しく追っていきます。

\begin{align}
\boldsymbol{x}_{k+1}
 &= {\rm prox}_{\eta g}( \boldsymbol{x}_k - \eta \bigtriangledown f(\boldsymbol{x}_k) ) \\
 &= {\rm argmin}_\boldsymbol{x} \left(
     \eta g(\boldsymbol{x}) + \frac{1}{2} \parallel \boldsymbol{x} - (\boldsymbol{x}_k - \eta \bigtriangledown f(\boldsymbol{x}_k)) \parallel^2
    \right) \\
 &= {\rm argmin}_\boldsymbol{x} \left(
     g(\boldsymbol{x}) + f(\boldsymbol{x}_k) + \frac{1}{2 \eta} \parallel \boldsymbol{x} - \boldsymbol{x}_k + \eta \bigtriangledown f(\boldsymbol{x}_k)) \parallel^2
    \right) \\
 &= {\rm argmin}_\boldsymbol{x} \left(
     g(\boldsymbol{x}) + f(\boldsymbol{x}_k) + \bigtriangledown f(\boldsymbol{x}_k)^T (\boldsymbol{x} - \boldsymbol{x}_k) + \frac{1}{2 \eta} \parallel \boldsymbol{x} - \boldsymbol{x}_k \parallel^2
    \right) \\
 &= {\rm argmin}_\boldsymbol{x} \left( g(\boldsymbol{x}) + \hat{f}_{\eta}(\boldsymbol{x};\boldsymbol{x}_k) \right) \\
\hat{f}_\eta(\boldsymbol{x};\boldsymbol{y})
 &\equiv f(\boldsymbol{y}) + \bigtriangledown f(\boldsymbol{y})^T (\boldsymbol{x} - \boldsymbol{y}) + \frac{1}{2\eta} \parallel \boldsymbol{x} - \boldsymbol{y} \parallel^2
\end{align}

3つ目の等号ではカッコ内全体を$\eta$で割って、定数$f(\boldsymbol{x}_k)$を足していますが、これにより解が変わることはありません。

\begin{align}
\parallel \boldsymbol{x} \parallel^2
 &= \boldsymbol{x}^T \boldsymbol{x} \\
\parallel \boldsymbol{a} + \boldsymbol{b} \parallel^2
 &= (\boldsymbol{a} + \boldsymbol{b})^T (\boldsymbol{a} + \boldsymbol{b}) \\
 &= \boldsymbol{a}^T \boldsymbol{a} + \boldsymbol{a}^T \boldsymbol{b} + \boldsymbol{b}^T \boldsymbol{a} + \boldsymbol{b}^T \boldsymbol{b} \\
 &= \parallel \boldsymbol{a} \parallel^2 + 2 \boldsymbol{a}^T \boldsymbol{b} + \parallel \boldsymbol{b} \parallel^2
\end{align}

より、3つ目の等号の第3項は、

\begin{align}
\frac{1}{2 \eta} \parallel (\boldsymbol{x} - \boldsymbol{x}_k) + \eta \bigtriangledown f(\boldsymbol{x}_k)) \parallel^2
 &= \frac{1}{2 \eta} \{ \parallel \boldsymbol{x} - \boldsymbol{x}_k \parallel^2
 + 2 \eta \bigtriangledown f(\boldsymbol{x}_k)^T (\boldsymbol{x} - \boldsymbol{x}_k)
 + \eta^2 \parallel \bigtriangledown f(\boldsymbol{x}_k) \parallel^2 \}
\end{align}

であり、解に影響のない定数$\frac{\eta}{2} \parallel \bigtriangledown f(\boldsymbol{x}) \parallel^2$を省くと4つ目の等号になります。

参考文献

[1] 金森敬文, 鈴木大慈, 竹内一郎, 佐藤一誠. 機械学習のための連続最適化(機械学習プロフェッショナルシリーズ), 講談社, 2016.