友人からXGBoostというアルゴリズムがよく使われると聞いたのでXGBoost: A Scalable Tree Boosting Systemという論文を(途中まで)読んでみました。理解を深めるために読んだ部分をまとめていこうと思います。
概要
機械学習初心者がXGBoost: A Scalable Tree Boosting Systemを読んだ備忘録。
1章(INTRODUCTION)に書かれている『XGBoostがいかに汎用性のあるアルゴリズムであるか』や『他のアルゴリズムと比較して何が優れているか』等についてはここでは説明しません。(直訳するだけになってしまうので。)
今回は2章(TREE BOOSTING IN A NUTSHELL)に書かれている内容について、論文で省略されている式変形を丁寧に行いながら自分なりの理解をまとめようと思います。
記号の定義
- $ D := \{ (\mathbf{x}_{i}, y_{i}) \}_{i = 1,2,...,n} \ \ (\mathbf{x}_{i} \in \mathbb{R}^{m}, y_{i} \in \mathbb{R}) $
訓練データ$\mathbf{x}_{i}$とラベル$y_{i}$の集合。
$\mathbf{x}_{i}$から$y_{i}$を予測する"良い"関数を得る事がアルゴリズムの目的となります。
- $ q : \mathbb{R}^{m} \to \{ 1,2,...,T \} \ \ (T \geq 1) $
訓練データの集合${\mathbf{x}_{i}}_{i = 1,2,...,n}$を$T$個の部分集合に分割する決定木。
後に多くの決定木が登場するが、$\ T\ $の値は各$\ q\ $ごとに異なります。
- $w : \{1,2,...,T\} \to \mathbb{R}$
決定木$q$の各葉の重みを定義する関数。
- $f : \mathbb{R}^{m} \to \mathbb{R} $
$q$と$w$の合成を$f$で表します。
- $\hat{y}_{i} := \phi(\mathbf{x}_{i}) := \sum_{k=1}^{K}f_{k}(\mathbf{x}_{i}) = \sum_{k=1}^{K}w_{k}\circ q_{k}(\mathbf{x}_{i})\ \ (K\geq 1)$
$K$個の決定木$q_{k}$とそれらの葉の重みを定義する関数$w_{k}$から得られる予測値。
- $\mathcal{L}(\phi) := \sum_{i=1}^{n}l(\hat{y}_{i},y_{i}) + \sum_{k=1}^{K}\Omega(f_{k})$
損失関数。
この$\mathcal{L}(\phi)$が小さくなるような$\phi$(つまり$q_{k}$と$w_{k}$)を見つける事が目標となります。
第一項の$l$と第二項の$\Omega$については以下で詳しく説明します。
- $l(\hat{y}_{i},y_{i})$
微分可能な凸関数。
$\hat{y}_{i}$がどれだけ$y_{i}$に近いかを表現する関数を$l$として採用します。
XGBoost Documentationによると平均二乗誤差やロジスティック誤差が採用されるらしい。
- $\Omega(f) := \gamma T + \frac{1}{2}\lambda ||w||^{2}\ \ (\gamma, \lambda \in \mathbb{R}_{\geq0})$
モデルの複雑化を防ぐ正則化項。この項を損失関数に加える事によって過学習を防ぐ事ができます。
今回の目標
まずは$K$個の決定木$q_{1},q_{2},...,q_{K}$に対してどのような$w_{1},w_{2},...,w_{K}$をとれば$\mathcal{L}(\phi)$を小さくできるかを考えます。
原論文で$w_{j}$という記号が出てくるが、これは$w_{t}(j)$を表しています。$t$番目の決定木を扱っているという事が伝わりづらいのでここでは$w_{t}(j)$という記号を使う事にします。
ステップ1 (ループの順番)
以下の関数を考えます。
$$\mathbf{obj}(t) := \sum_{i=1}^{n}l(\hat{y}_{i}^{(t)},y_{i}) + \sum_{k=1}^{t}\Omega(f_{k})\ \ (1\leq t \leq K)$$
ここで
$$\hat{y}_{i}^{(t)} := \sum_{k=1}^{t}f_{k}(\mathbf{x}_{i})$$
と定義しています。明らかに
$$\mathbf{obj}(K) = \mathcal{L}(\phi)$$
$$\hat{y}_{i}^{(K)} = \hat{y}_{i}$$
が成り立ちます。
まず$\mathbf{obj}(1)$を小さくする$w_{1}$を求め、次に$\mathbf{obj}(2) - \mathbf{obj}(1)$を小さくする$w_{2}$を求めます。
同様に各$t$において$\mathbf{obj}(t) - \mathbf{obj}(t-1)$を小さくする$w_{t}$を順番に求めていきます。
これによって
$$\mathcal{L}(\phi) = \mathbf{obj}(K) = (\sum_{t=2}^{K}\mathbf{obj}(t) - \mathbf{obj}(t-1)) + \mathbf{obj}(1)$$
を小さくする$w_{1},w_{2},...,w_{K}$を求める事ができる。
ステップ2 (損失関数の式変形)
$\mathbf{obj}(t)$を以下のように式変形できます。
\begin{align}
\mathbf{obj}(t) &= \sum_{i=1}^{n}l(\hat{y}_{i}^{(t)},y_{i})+\sum_{k=1}^{t}\Omega(f_{k}) \\\
&=\sum_{i=1}^{n}l(\hat{y}_{i}^{(t-1)} + f_{t}(\mathbf{x}_{i}), y_{i})+\sum_{k=1}^{t}\Omega(f_{k})\\
&\fallingdotseq\sum_{i=1}^{n}[l(\hat{y}_{i}^{(t-1)}, y_{i}) + g_{i}f_{t}(\mathbf{x}_{i}) + \frac{1}{2}h_{i}f_{t}(\mathbf{x}_{i}) )]+(\sum_{k=1}^{t-1}\Omega(f_{k})) + \Omega(f_{t})\\\
&=\mathbf{obj}(t-1) + \sum_{i=1}^{n}[g_{i}f_{t}(\mathbf{x}_{i}) + \frac{1}{2}h_{i}f_{t}(\mathbf{x}_{i})] + \Omega(f_{t})
\end{align}
ここで
$$
g_{i} := \partial_{\hat{y}^{(t-1)}}l(\hat{y}_{i}^{(t-1)}, y_{i})
$$
$$
h_{i} := \partial_{\hat{y}^{(t-1)}}^{2}l(\hat{y}_{i}^{(t-1)}, y_{i})
$$
と定義しています。よって
$$
\tilde{\mathcal{L}}^{(t)} := \sum_{i=1}^{n}[g_{i}f_{t}(\mathbf{x}_{i}) + \frac{1}{2}h_{i}f_{t}(\mathbf{x}_{i})] + \Omega(f_{t})
$$
と定義すれば、これが$\mathbf{obj}(t) - \mathbf{obj}(t-1)$を近似する関数となる。
$$
I_{j}=\{ i\ |\ q_{t}(\mathbf{x}_{i})=j \}
$$
と定義すると、さらに以下のように式変形できます。
\begin{align}
\tilde{\mathcal{L}}^{(t)} &= \sum_{i=1}^{n}[g_{i}f_{t}(\mathbf{x}_{i}) + \frac{1}{2}h_{i}f_{t}(\mathbf{x}_{i})^{2}] + \Omega(f_{t}) \\
&= \sum_{i=1}^{n}[g_{i}\times(w_{t}(q_{t}(\mathbf{x}_{i}))) + \frac{1}{2}h_{i}\times(w_{t}(q_{t}(\mathbf{x}_{i})))^{2}] + \gamma T + \frac{1}{2}\lambda\sum_{j=1}^{T}w_{t}(j)^{2}\\
&= \sum_{j=1}^{T}[(\sum_{i\in I_{j}}g_{i})w_{t}(j) + \frac{1}{2}(\sum_{i\in I_{j}}h_{i} + \lambda)w_{t}(j)^{2}] + \lambda T
\end{align}
よって$w_{t}$を
$$
w_{t}(j) = \frac{\sum_{i\in I_{j}}g_{i}}{(\sum_{i\in I_{j}}h_{i}) + \lambda}
$$
のように定義すれば$\tilde{\mathcal{L}}^{(t)}$を最小にする$w_{t}$が得られます。
結論
$K$個の決定木$q_{1},...,q_{k}$に対してどのような$w_{1},...,w_{K}$をとれば損失関数$\mathcal{L}(\phi)$が小さくなるかがわかった。
おわりに
ここまでの内容を踏まえてどのような$q_{1},...,q_{k}$をとれば良いかという話が続きに書いてあります。
余裕があれば続きも書こうと思います。