勾配ブースティング決定木 (Gradient Boosting Decision Tree, GBDT) について理論的な解説をしている書籍が少なかったので、GBDTの中でも他の手法の基礎となっていそうなxgboostの論文を読んでみました。個人的に読んだ内容を整理するついでに記事にしたいと思います。
xgboostの論文を解説した記事は既に複数あるので、真新しい内容はあまりないかと思いますが、気になったことや自分なりの解釈を書きたいと思います。特に全体を通しての流れであったり、手法のモチベーションなどが分かるように書きたいと思います。また、この記事で重視する分かりやすさとは『平易であること』や『直感的であること』ではなく、『論理の流れ』を明確にすることを目指したいと思っています。
なお、この記事で扱うのは論文の全てでなく、xgboostの仕組みを解説した"2. TREE BOOSTING IN A NUTSHELL"を主な対象としています。
はじめに
細かい内容に入る前に、まずはxgboostの仕組みをざっくり確認します。
xgboostはいわゆるアンサンブル学習であり、複数の決定木の結果を合成して最終的な予測値を出力します。結果の合成方法は各決定木の予測値の加算です。つまり、インプットデータ$\textbf{x}_i$と対応する予測値$\hat{y}_i$に対し、決定木の本数を$K$とすると
\hat{y}_i = \sum_{k=1}^{K}f_k(\textbf{x}_i),
\qquad f_k \in \mathcal{F} \tag{1}
と書き表せます。ここで、$f_k$は各決定木の予測値に対応し、$\mathcal{F}$は、
\mathcal{F}=\lbrace f(\textbf{x}) = w_{q(\textbf{x})} \rbrace(q: \mathbb{R}^m \to T, w \in \mathbb{R}^T) \tag{2}
です。$\mathcal{F}$については表記に慣れていないとすんなり理解できないかもしれないので、もう少しだけ説明します。まず、写像$q$は$m$次元実ベクトル空間の集合$T$への写像であり、$m$次元実ベクトル空間はインプットデータに対応し、集合$T$の要素は決定木の各葉 (leaf) に対応します。特徴量のカラムが$m$列のデータを決定木に流すと、どこかしらのリーフノードに辿り着くイメージです。$|T|$(集合$T$のサイズ)はリーフノードの総数に対応します。$w$は各リーフノードの値で、分類木であれば最終的にはどのクラスに属するかというラベルになりますが、xgboostでは回帰木が前提なので、回帰の値(予測値)になります。つまり、$f(\textbf{x})=w_{q(\textbf{x})}$という式は、インプットデータを決定木に流し込み、辿り着いたリーフノードの値を読み取るという通常の回帰木の流れに相当しているわけです。
さて、少し説明がくどくなってしまったかも知れませんが、正直この辺りは論文の図を見てもらうのが一番わかりやすいかと思うので、ピンと来なかった方は論文のFigure 1を見ていただくと良いかと思います。
さて、ここで気になるのはアンサブルされる各決定木の関係性です。結論を先に述べてしまうと、xgboostではイテレーションを回して1個ずつ木を増やしていきます。具体的な方法は後述しますが、各イテレーションで追加する決定木は前回までに生成された木の回帰結果を改善することに集中し、この繰り返しによってxgboostは予測精度を高めていくわけです。
以上を踏まえて、この後xgboostの詳細に入っていきたいと思いますが、最後に解くべき問題を整理しましょう。xgboostの学習過程において、決定しなければならない事項は以下の2つです。
- 各イテレーションで生成する決定木の構造:$q(\textbf{x})$
- 生成した決定木における各リーフノードの予測値:$w_{q(\textbf{x})}$
以降の内容はこの辺りを意識して読んでいただくとより流れがわかりやすいかと思います。
正則化損失関数 (Regularized Learning Objective)
それでは、xgboostの内容の解説に入っていきます。
まずは問題を具体的に定式化しましょう。ということで損失関数を定義しますが、ここで正則化項を考慮した損失関数を考えます。
\mathcal{L} = \sum_il(\hat{y}_i, y_i) + \sum_k \Omega(f_k) \tag{3}
\mathrm{where} \quad \Omega(f) = \gamma T + \frac{1}{2}\lambda \parallel w\parallel^2
$y_i$は正解の値、$l$は微分可能な損失関数です。例えばRMSEとかを想像していただければ良いと思います。右辺第2項が正則化項に対応しており、リーフノードの数 ($T$) とリーフノードの予測値のL2ノルム ($\parallel w\parallel^2$) に対して、罰則を課しています。それぞれ、過学習を抑えることに寄与しそうなことはなんとなく理解できるかと思います。なお、正則化$\Omega$はxgboostで生成される決定木一本あたりに定義されます。
勾配ブースティング (Gradient Tree Boosting)
さて、損失関数が定義できたので、損失を最小化 (最適化) するパラメータを求めていきます。
「はじめに」の最後でも述べましたが、最適化の対象となるパラメータは、①決定木の構造と②リーフノードの値で、これらの値を一本ずつ決定木を増やしながら最適化していきます。この問題を一度に一気に考えるのは大変なのでひとつひとつほぐしていきましょう。
イテレーションに対する漸化式
まずは個々のイテレーションでの最適化を独立した問題として定式化したいので、イテレーション部分を漸化式的に書き直します。現在、$t$ステップ目のイテレーションにいるとして、$t-1$ステップまでの$t-1$本の決定木の構造とリーフノードの値は既に求まっているとします。第$t$ステップの損失関数は、(1)式より、
\mathcal{L}^{(t)} = \sum_{i=1}^nl(y_i, \hat{y}_i^{(t-1)} + f_t(\textbf{x}_i)) + \sum_{t'=1}^{t} \Omega (f_{t'}) \tag{4}
と書け、第$t$ステップにおける最適化パラメータのうち、損失関数に陽に現れている変数は$f_t$と$T$です。そう考えると、右辺第2項の$t'$に関する和のうち、$t-1$までの和は第$t$ステップ時点では確定している値で定数と見做せるので、無視します。
\mathcal{L}^{(t)} = \sum_{i=1}^nl(y_i, \hat{y}_i^{(t-1)} + f_t(\textbf{x}_i)) + \Omega (f_{t}) \tag{5}
これで各イテレーションにおける最適化問題を独立に取り扱う準備ができました。また、この定式化によって、$t$ステップ目の決定木は$t-1$ステップ目までの結果を改善することに特化していることが分かります。
リーフノードの値の最適化
次に各イテレーションにおける決定木の木構造とリーフノードの値の最適化についてですが、問題の依存関係として、リーフノードの値は決定木の構造を与えなければ最適化することができません。逆に、与えられた木構造の中でリーフノードの値の最適解を求める方法が構築できれば、あとはその方法をもって各木構造を比較すれば自ずと最適化ができます。ということで、与えられた木構造においてリーフノードの値を最適化することを考えます。木構造が与えられているというのは、(2)式における$q$ (つまり$T$) が定義されているということです。よって(5)式の損失関数のうち最適化の対象となるパラメータは$f_t$のみとなります。ただし、この最適化問題は素直には扱いにくいので、損失関数を近似することを考えます。具体的には$l$の勾配を用いて、
\mathcal{L}^{(t)} \simeq \sum_{i=1}^n [l(y_i, \hat{y}^{(t-1)}) + g_if_t(\textbf{x}_i) + \frac{1}{2} h_if_t^2(\textbf{x}_i)] + \Omega(f_t), \tag{6}
g_i = \left.\frac{\partial l(y_i, y)}{\partial y}\right|_{y=\hat{y}^{(t-1)}},
h_i = \left.\frac{\partial^2 l(y_i, y)}{\partial y^2}\right|_{y=\hat{y}^{(t-1)}},
と近似します。つまり$f_t$の2次まで展開して近似します。このとき、先ほどと同様に第$t$ステップ時には値が確定している$l(y_i, \hat{y}^{(t-1)})$は定数であり、最適化の観点では無視できるので、
\mathcal{\tilde{L}}^{(t)} = \sum_{i=1}^n[g_if_t(\textbf{x}_i) + \frac{1}{2} h_if_t^2(\textbf{x}_i)] + \Omega(f_t) \tag{7}
と定義し直します。(3)式より右辺第2項の$\Omega(f_t)$も$f_t$については2次式となっており、$\mathcal{\tilde{L}}^{(t)}$は全体として$f_t$の二次関数となっています。ようやく解析的に解けそうになってきましたね(解析的に解くために近似していたわけですが)。あともうひと息です。ここまできたら正則化項も明示的に書いてしまいましょう。
\mathcal{\tilde{L}}^{(t)} = \sum_{i=1}^n[g_if_t(\textbf{x}_i) + \frac{1}{2} h_if_t^2(\textbf{x}_i)] + \gamma T + \frac{1}{2}\lambda \sum_{j=1}^T w_j^2 \tag{8}
さて、上式では和が二箇所で登場していますが、それぞれの和の取り方が揃っていません。初めの和はデータの単位でカウントしており、二個目の和は決定木のリーフノードの単位でカウントしています。いま、決定木の木構造は与えられているとしているため、すべてのデータはいずれかのリーフノードに対応します。そこで和の取り方をリーフノードの単位に揃えましょう。実際、最適化をしたいのはリーフノードの値に対してなので、揃えないと先に進めません。ここで、論文には登場しませんが、クロネッカーのデルタを使いたいと思います。詳細は調べていただきたいと思いますが、簡単に説明するとクロネッカーのデルタは以下のように定義されます。
\begin{equation}
\delta_{i, j} =
\begin{cases}
1 & (i=j) \\
0 & (i\neq j)
\end{cases}
\end{equation}
クロネッカーのデルタを使うと、先の和の変換が可能となります。例えば
\begin{equation}
\sum_{i=1}^n f(\textbf{x}_i)
= \sum_{i=1}^n w_{q(\textbf{x}_i)}
= \sum_{i=1}^n\sum_{j=1}^T w_j \delta_{j, q(\textbf{x}_i)}
= \sum_{j=1}^T \left( \sum_{i=1}^n \delta_{j, q(\textbf{x}_i)} \right) w_j
\end{equation}
といった具合に書き直せます。そして、クロネッカーのデルタを用いて和の取り方を揃えると、
\begin{align}
\mathcal{\tilde{L}}^{(t)} &= \sum_{i=1}^n[g_if_t(\textbf{x}_i) + \frac{1}{2} h_if_t^2(\textbf{x}_i)] + \gamma T + \frac{1}{2}\lambda \sum_{j=1}^T w_j^2 \\
&=\sum_{j=1}^T \left[ \left( \sum_{i=1}^n \delta_{j, q(\textbf{x}_i)}g_i \right) w_j + \frac{1}{2} \left( \sum_{i=1}^n \delta_{j, q(\textbf{x}_i)}h_i + \lambda \right) w_j^2 \right] + \gamma T \tag{9}
\end{align}
と書けます。あとは上式を最小化する$w_j^*$を求めるだけで、これは上式を$w_j$について平方完成して、2乗の部分が0の時を解くだけで容易に求まります。一応結果のみ記載しておくと、
w_j^* = -\frac{\sum_{i=1}^n \delta_{j, q(\textbf{x}_i)}g_i}{\sum_{i=1}^n \delta_{j, q(\textbf{x}_i)}h_i + \lambda} \tag{10}
で、この時の$\mathcal{\tilde{L}}^{(t)}$の値は、
\mathcal{\tilde{L}}^{(t)*}(q)
= -\frac{1}{2} \sum_{j=1}^T \frac{(\sum_{i=1}^n \delta_{j, q(\textbf{x}_i)}g_i)^2}{\sum_{i=1}^n \delta_{j, q(\textbf{x}_i)}h_i + \lambda} + \gamma T \tag{11}
です。
[疑問]
損失関数を2次で近似することを正当化するためには、$f_t$が$\hat{y}^{(t-1)}$と比べて微小であることが必要かと思いますが、これがどのように保障されているのかは分かりませんでした。直感的には、各イテレーションで予測値を改善するため、統計的(?)には$f_t$が$\hat{y}^{(t-1)}$より小さくなりそうな気はしますが、この辺どうなんでしょう。。
木構造の最適化
これでようやく木構造を与えた時のリーフノードの値の最適解(の近似値)が求まりました。あとは各木構造$q$について、$\mathcal{\tilde{L}}^{(t)*}(q)$を比較するだけですが、一般的な問題設定ではすべての木構造を生成して比較しようとすると組合せ爆発してしまうので貪欲法を使って求めます。初めは一つのリーフ(最終的にはルートになる)から初めて、一個ずつ分岐を増やしていき徐々に木を成長させていくやり方です。実はこの方法で分岐を増やす時に先ほどの結果の(11)式が利用できて、(11)式をもとに特定の閾値による分岐を増やした場合の利得(損失関数の減少量)が計算できます。分岐を追加する前の木構造および分岐の追加を検討しているリーフノード�のインデックスをそれぞれ$q$と$I$とし、分岐を追加した後の木構造および追加されたリーフノードのインデックスをそれぞれ$\tilde{q}$と$R, L$(分岐の追加でリーフノードは2個追加されます)と呼ぶことにすると、分岐の追加による損失関数の減少量は、
\mathcal{L}_{split} = \frac{1}{2} \left[
\frac{(\sum_{i=1}^n \delta_{L, \tilde{q}(\textbf{x}_i)}g_i)^2}{\sum_{i=1}^n \delta_{L, \tilde{q}(\textbf{x}_i)}h_i + \lambda}
+
\frac{(\sum_{i=1}^n \delta_{R, \tilde{q}(\textbf{x}_i)}g_i)^2}{\sum_{i=1}^n \delta_{R, \tilde{q}(\textbf{x}_i)}h_i + \lambda}
-
\frac{(\sum_{i=1}^n \delta_{I, q(\textbf{x}_i)}g_i)^2}{\sum_{i=1}^n \delta_{I, q(\textbf{x}_i)}h_i + \lambda}
\right]
- \gamma \tag{12}
と計算できます。右辺の最後の$\gamma$は分岐の追加によってリーフノードの総数が1増えることに対する正則化項に対応しています。これで木構造の作り方も分かりましたね。
終わりに
お疲れ様です!これで一通りの解説が終わりました。今回説明したのは、基本的な考え方のみで、高速化のための諸テクニックなどには触れていませんのでその辺りはご注意いただきたいと思いますが、xgboostが何をやっているかをだいぶ明らかにできたのではないかと思っています。特に全体を通しての流れであったり、手法のモチベーションなどが分かるように書いたつもりですが、分かりづらかったら申し訳ないです。ネットには他にもたくさんの解説記事などがあるのでそちらを読んでいただけたらと思います!また、記述内容に誤り等があればコメントいただけますと助かります。