はじめに
前回・前々回の記事では、バッチの近接勾配法とその加速法について紹介しました。バッチの近接勾配法では、各更新のたびに、学習データの全サンプルに対する勾配を求めて更新していました。バッチ勾配法は一回一回の更新で、確実に目的関数を減らしていくことができたり、収束の具合をモニタリングできるので安心感があるのですが、学習データが大きくなってくると一回一回の更新が重くなりすぎてつらくなってきます。
そこで今回は1サンプルに対する勾配計算で更新するオンライン学習法の一手法であるRegularized Dual Averaging [1]を紹介します。バッチが全サンプルに対する勾配計算をして1回しか更新しないのに対して、全$N$サンプルに対する勾配計算をする頃にはN回更新されるので、実行時間に対する収束レートとしてお得[2]になります(※)。確率的近接勾配法でだいたい収束させてからバッチ勾配法で収束させる、とかの使い方もできるかと思います。
※少し実体験を書きますと、バッチ更新が可能なデータの規模であったりモデル(たとえば一回データをスキャンして統計量を計算したら、毎回同じ統計量を使いまわして更新できる(データをスキャンし直す必要がない)線形二乗誤差系のモデルとか)では、更新に必要な計算を並列化できたり、オンライン学習の挙動に悩まされずに安定して解が得られるバッチ更新則の方が良いと思います。とはいえ、バッチ更新が現実的には困難、というようなモデルも多いので、ここで紹介するオンライン学習法も有用です。が、オンライン学習も並列化しないと実際にはツライです。
復習:近接勾配法
近接勾配法は、微分可能な凸関数$f(\boldsymbol{x})$と微分不可能な点を含む凸関数$g(\boldsymbol{x})$に対して、
\begin{align}
\boldsymbol{x}_{k+1}
&= {\rm prox}_{\eta g}( \boldsymbol{x}_k - \eta \bigtriangledown f(\boldsymbol{x}_k) ) \\
{\rm prox}_g(\boldsymbol{y})
&\equiv {\rm argmin}_\boldsymbol{x} \left\{ g(\boldsymbol{x}) + \frac{1}{2} \parallel \boldsymbol{x} - \boldsymbol{y} \parallel^2 \right\}
\end{align}
と更新しました。あるいは、式展開をして
\begin{align}
\boldsymbol{x}_{k+1}
&= {\rm argmin}_\boldsymbol{x} \left(
\bigtriangledown f(\boldsymbol{x}_k)^T (\boldsymbol{x} - \boldsymbol{x}_k) + g(\boldsymbol{x}) + \frac{1}{2 \eta} \parallel \boldsymbol{x} - \boldsymbol{x}_k \parallel^2
\right) \\
\end{align}
の形式で記載されているものもよく見ます。
機械学習の文脈に馴染みやすいよう最適化対象のパラメータを$\boldsymbol{w}$、サンプル$n$に対する損失関数を$f_n(\boldsymbol{w})$、正則化項を$\lambda \Psi(\boldsymbol{w})$とすると、目的関数は
\begin{align}
\sum_{n=1}^N f_n(\boldsymbol{w}) + \lambda \Psi(\boldsymbol{w})
\end{align}
の形になることが多いです。さらに、オンライン学習の文脈に馴染みやすいように、『1サンプルあたり』の正則化項$\lambda \psi(\boldsymbol{w})=\lambda \Psi(\boldsymbol{w})/N$を使うことにして、次の目的関数最小化を考えます。
\begin{align}
\frac{1}{N} \sum_{n=1}^N f_n(\boldsymbol{w}) + \lambda \psi(\boldsymbol{w})
\end{align}
バッチの近接勾配法の更新式で$f(\boldsymbol{x}_k) \rightarrow \sum_{n=1}^N f_n(\boldsymbol{w}_k)$、$g(\boldsymbol{x}) \rightarrow N \lambda \psi(\boldsymbol{w})$と置き換えると
\begin{align}
\boldsymbol{w}_{k+1}
&= {\rm argmin}_\boldsymbol{w} \left(
\frac{1}{N} \sum_{n=1}^N \bigtriangledown f_n(\boldsymbol{w}_k)^T (\boldsymbol{w} - \boldsymbol{w}_k) + \lambda \psi(\boldsymbol{w}) + \frac{1}{2N \eta} \parallel \boldsymbol{w} - \boldsymbol{w}_k \parallel^2
\right) \\
\end{align}
となります。第一項で、全てのサンプル$n=1,\cdots,N$について現在の点$\boldsymbol{w}_k$での勾配$\bigtriangledown f_n(\boldsymbol{w}_k)$を計算していますね。
確率的近接勾配法:Regularized Dual Averaging
ステップ$t$におけるパラメータ$\boldsymbol{w}_t$に対して、損失$f_t(\boldsymbol{w})$を被る問題設定におけるRegularized Dual Averaging(RDA)[1]の更新式は次のようになります。
\begin{align}
\boldsymbol{w}_{t+1}
&= {\rm argmin}_\boldsymbol{x} \left(
\frac{1}{t} \sum_{\tau=1}^t \bigtriangledown f_\tau(\boldsymbol{w}_\tau)^T (\boldsymbol{w} - \boldsymbol{w}_\tau) + \lambda \psi(\boldsymbol{w}) + \frac{1}{2t \eta_t} \parallel \boldsymbol{w} \parallel^2
\right) \\
&= {\rm argmin}_\boldsymbol{x} \left(
\bar{g}_t^T \boldsymbol{w} + \lambda \psi(\boldsymbol{w}) + \frac{1}{2t \eta_t} \parallel \boldsymbol{w} \parallel^2
\right) \\
\bar{g}_t
&\equiv \frac{1}{t} \sum_{\tau=1}^t \bigtriangledown f_\tau(\boldsymbol{w}_\tau)
\end{align}
$\bar{g}_t$は$\tau=1,\cdots,t$の各ステップで求めた勾配$\bigtriangledown f_\tau(\boldsymbol{w}_\tau)$の平均です。
近接勾配法の更新式との違いとしてはまず第一項で、現在の点$\boldsymbol{w}_t$についての勾配は、サンプル$t$に対する勾配$\bigtriangledown f_t(\boldsymbol{w}_t)$のみを計算しています。サンプル一つ分の勾配だけ計算する点はStochastic Gradient Descentと一緒ですね。が、過去の勾配との平均を取ったり、更新式に最小化が入っているなどSGDとは大きく違っています。(SGDをそのまま微分不可能な点を含む凸関数最適化に使えるようにしたような手法としてはStochastic Subgradient Methodがあります。が、この手法ではL1正則化などでも厳密に0にはならず、スパース解が得られません。)
また第三項が、近接勾配法では現在の点$\boldsymbol{w}_k$から離れるほどペナルティが掛かる項になっているのに対して、RDAでは$\boldsymbol{w}$のノルムに対してペナルティが掛かっています。(この項を過去の系列$\boldsymbol{w}_\tau$($\tau=1,\cdots,t$)との二乗誤差の和$\sum_{\tau=1}^t \sigma_\tau \parallel \boldsymbol{w} - \boldsymbol{w}_\tau \parallel^2$とするFTRL-Proximal [3]のような手法もあるのですが、性能的には大きく違わないようです。)
バッチの近接勾配法では現在の点$\boldsymbol{w}_k$で近傍での二次近似を最小化している、と解釈できたのですが今回は勾配が平均勾配になっているし、そのような解釈はできません。収束性はRegret解析によって示されています。直感的な解釈は難しいのですが、何にせよ収束性が示されるなら、収束性が良くて、使い勝手が良い(いろいろな正則化項に対して更新式が求まりやすい)手法が望ましいのですが、RDAは使い勝手が素晴らしく良いのです!
式を展開すると
\begin{align}
\boldsymbol{w}_{t+1}
&= {\rm argmin}_\boldsymbol{w} \left(
\bar{g}_t^T \boldsymbol{w} + \lambda \psi(\boldsymbol{w}) + \frac{1}{2t \eta_t} \parallel \boldsymbol{w} \parallel^2
\right) \\
&= {\rm argmin}_\boldsymbol{w} \left(
\lambda \psi(w) + \frac{1}{2t \eta_t} \parallel \boldsymbol{w} + t \eta_t \bar{g}_t \parallel^2
\right) \\
&= {\rm prox}_{t \eta_t \lambda \psi}( -t \eta_t \bar{g}_t ) \\
\bar{g}_t
&= \frac{t-1}{t} \bar{g}_{t-1} + \frac{1}{t} \bigtriangledown f_t(\boldsymbol{w}_t)
\end{align}
となります。3つ目の等号では、Proximal Operatorの定義:
\begin{align}
{\rm prox}_\psi(\boldsymbol{y})
&\equiv {\rm argmin}_\boldsymbol{x} \left\{ \psi(\boldsymbol{x}) + \frac{1}{2} \parallel \boldsymbol{x} - \boldsymbol{y} \parallel^2 \right\}
\end{align}
を使っています。
この式は超すごい! と思うのでもう一度書きます。
\begin{align}
\boldsymbol{w}_{t+1}
&= {\rm prox}_{t \eta_t \lambda \psi}( -t \eta_t \bar{g}_t )
\end{align}
平均勾配に対してProximal Operatorを適用するだけです! つまり、Proximal Operatorがわかっている正則化項であれば、超簡単にオンライン学習できちゃうのです。 しかも、Proximal Operatorを適用していることからわかるように、L1正則化などで落ちる係数はExactに0になりスパース解が得られます。ステップ幅は$\eta_t = \gamma / \sqrt{t}$($\gamma > 0$)のように選ぶことで収束レートが$O(\sqrt{t})$になることが示されています。
参考文献
[1] Xiao, Lin. "Dual averaging methods for regularized stochastic learning and online optimization." Journal of Machine Learning Research 11.Oct (2010): 2543-2596.
[2] 鈴木大慈. "機械学習におけるオンライン確率的最適化の理論." 情報処理学会連続セミナー2013 (Slideshare)
[3] McMahan, H. Brendan, et al. "Ad click prediction: a view from the trenches." Proceedings of the 19th ACM SIGKDD international conference on Knowledge discovery and data mining. ACM, 2013.