0
1

sklearn.tree.DecisionTreeRegressorのデフォルト実装

Posted at

この記事の内容

scikit-learn の Decision Tree について説明します。
デフォルトの学習/推論/特徴量重要度算出アルゴリズムはできるだけ詳しく説明しましたが、それ以外の内容(パラメータやアトリビュート、使用上の注意など)にはあまり触れていません。
より詳しく知りたい方は末尾の参考文献(scikit-learnのドキュメント)をご覧ください。

木構造

まず用語の説明をします。
文献によって使われる用語は微妙に異なりますが、この記事では以下の用語で説明をします。

木(有向木)とはノードとパスを持つ以下のような構造の事を指します。
スクリーンショット 2024-01-23 21.36.35.png
根から順番に枝分かれをしていき、末端のノードを葉と呼びます。
各パスに対して、パスの根本を親ノード、行き先を子ノードと呼びます。

木構造をAIモデルに応用する場合、以下のように説明変数に対する条件の真偽で学習データの分割をします。分割が終わった後、それぞれの葉に対応する予測値を算出します。
スクリーンショット 2024-01-28 16.43.55.png
推論時には、入力されたサンプルがどの葉に入るかを各条件に従って計算し、入力されたサンプルが入った葉の予測値をモデルの予測値として返します。

sklearn.tree.DecisionTreeRegressor

sklearn.tree.DecisionTreeRegressor のデフォルトのアルゴリズム(学習/推論/特徴量重要度算出)を解説します。
説明に使うデータセットの説明と学習済みモデルの可視化をした後、「推論 → 特徴量重要度算出 → 学習」の順番でアルゴリズムの説明をします。

データセット

以下のデータセットを使って学習してみます。

df = pd.DataFrame({'explanatory1': (0, 1, 2, 3, 4, 5, 6, 7),
                  'explanatory2': (3, 6, 5, 4, 7, 0, 10, 1),
                  'objective': (5, 1, 5, 10, 4, 2, 9, 3)})
df

スクリーンショット 2024-01-26 18.35.24.png

学習して可視化

学習をして可視化してみます。

完全にデフォルト設定のまま学習をすると全ての葉に含まれるサンプル数が 1 になってしまうので、各ノードのサンプル数が 2 以上になるオプションをつけて学習します。

# 学習
regressor = DecisionTreeRegressor(min_samples_leaf=2)
regressor = regressor.fit(df.loc[:,["explanatory1", "explanatory2"]].values, df["objective"].values)

# 可視化
plt.figure(figsize=(45, 10))
plot_tree(regressor, feature_names=df.columns, node_ids=True, filled=True)
plt.show()

スクリーンショット 2024-01-26 18.41.02.png
scikit-learnにはplot_treeという関数が用意されていて、学習済みDecisionTreeを可視化できます。

上記の画像に含まれている情報から、このモデルがどのような推論結果を返すのか、各説明変数の特徴量重要度がどうなるか、を全て読み取る事ができます。

まず各ノードに書いてある式を説明します。
スクリーンショット 2024-01-28 17.27.00.png

推論アルゴリズム

以下のサンプルの推論が上記の学習済みモデルでどのように行われるか観察します。

説明変数
explanatory1 1
explanatory2 2

推論の流れは以下のようになります。
スクリーンショット 2024-01-28 20.53.56.png
node #0 では「explanatory2 ≤ 2.0」が条件になっており、条件を満たしているため左の node #1 に移ります。

node #1 は葉なので対応する予測値(value = 2.5)を返します。

実際に以下のコードで予測値を確認すると 2.5 になっています。
スクリーンショット 2024-04-14 17.42.39.png

もう一回、値を変えて観察してみます。

説明変数
explanatory1 2
explanatory2 3

この場合の推論の流れは以下のようになります。
スクリーンショット 2024-01-29 15.43.54.png
node #0 では「explanatory2 ≤ 2.0」が条件になっており、条件を満たしていないため右の node #2 に移ります。

node #2 では「explanatory1 ≤ 2.5」が条件になっており、条件を満たしているため左の node #3 に移ります。

node #3 は葉なので対応する予測値(value = 3.667)を返します。

実際に以下のコードで予測値を確認すると 3.667 になっています。
スクリーンショット 2024-04-14 17.42.46.png

特徴量重要度算出アルゴリズム

ある説明変数の特徴量重要度 (impurity-based feature importances) は、その説明変数を分割の条件に使っている各ノードの samples と mse から以下のように計算できます。

\sum_{i}(Samples(i)\times MSE(i) - Samples_{LeftChild}(i)\times MSE_{LeftChild}(i) - Samples_{RightChild}(i)\times MSE_{RightChild}(i))

ここで $i$ はその説明変数を分割の条件に使っている各ノードの ID 全体を走り、各記号は以下のように定義します。(node #$i$ は葉ではない事に注意してください)

\displaylines{
Samples(i) = node \ \#iのsamples\\
MSE(i) = node \ \#iのmse\\
Samples_{LeftChild}(i) = node \ \#iの左側子ノードのsamples\\
MSE_{LeftChild}(i) = node \ \#iの左側子ノードのmse\\
Samples_{RightChild}(i) = node \ \#iの右側子ノードのsamples\\
MSE_{RightChild}(i) = node \ \#iの右側子ノードのmse
}

つまり各ノードに対して $samples \times mse \ (= squared \ error)$ を不純度(impurity)とみなして、それをどれだけ下げられたかで特徴量重要度を計算しているイメージです。

学習済みモデルに保存されている特徴量重要度は上記の式を正規化したものになっており、上で学習した DecisionTreeRegressor の場合は以下のように計算できます。
スクリーンショット 2024-04-14 17.55.25.png
学習済みモデルの保存されている特徴量重要度は以下のコードで確認でき、上で計算したものと一致している事がわかります。
(plot_treeで出力される数値が小数点第四位くらいまでなので、それ未満の値はズレています。)
スクリーンショット 2024-04-14 17.55.45.png

学習アルゴリズム

新しい分割を定義するためには葉 $i$ 、説明変数 $j$ , 閾値 $t$ を決める必要があります。

葉 $i$ に対してDecision Tree (のデフォルト設定)では以下の式が非負になる$(j, t)$が存在する場合に分割をします。(おそらくサンプル数が 0 のノードは作成されません。)

Samples(i)\times MSE(i) - Samples_{LeftChild}(i)\times MSE_{LeftChild}(i) - Samples_{RightChild}(i)\times MSE_{RightChild}(i)

$(j, t)$は $j ≤ t$ を分割の条件として子ノードを定義したときに、以下の値が最小になるように選びます。

Sample_{LeftChild}(i) \times MSE_{LeftChild}(i) + Sample_{RightChild}(i) \times MSE_{RightChild}(i)

特徴量重要度算出と同じように $samples \times mse \ (= squared \ error)$ を不純度(impurity)とみなして、不純度が最小になるように分割を繰り返すイメージです。

参考文献

sklearn.tree.DecisionTreeRegressor
1.10. Decision Trees

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1