11
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

scikit-learnのDecisionTreeClassifier基本使い方: 枝刈りと決定木描画

Posted at

scikit-learnのDecisionTreeClassifierの基本的使い方を解説します。
訓練、枝刈り、評価、決定木描画をしていきます。

環境

Python3.7.13で1Google Colaboratory上で動かしています。Google Colabプリインストールされているパッケージはそのまま使っています。
最近気づいたのですがscikit-learnはPython3.7ではもう更新しなくなっていますね。

Package Version 備考
scikit-learn 1.0.2 Google Colabプリインストール
matplotlib 3.2.2 Google Colabプリインストール
numpy 1.21.6 Google Colabプリインストール
pandas 1.3.5 Google Colabプリインストール
dtreeviz 1.3.7 手動インストール

プログラム

1. Import

パッケージインポート。

from dtreeviz.trees import dtreeviz
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_classification
from sklearn.metrics import classification_report, ConfusionMatrixDisplay, RocCurveDisplay, PrecisionRecallDisplay
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree

2. データ生成

make_classification関数を使って5特徴量のデータを生成。

FEATURES = ['X0', 'X1', 'X2', 'X3', 'X4']
def make_df():
    X, y = make_classification(n_samples=1100, 
                               n_features=5, n_redundant=0)
    X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=1000)
    return X_train, X_test, y_train, y_test

X_train, X_test, y_train, y_test = make_df()

make_classification関数の使い方はこちらが詳しい。

3. 枝刈り準備

cost_complexity_pruning_path関数を使って、枝刈りのためのパス計算をします。DecistionTreeは木を深くした後に枝刈りをして最適化するようです。深さは5に制限しておきますが、これはケースバイケースです。
※仕事で使ったときには、この前にmax_depthcriterion等の値をハイパーパラメータ探索しました。

def get_ccp_alphas(X_train, y_train):
    path = DecisionTreeClassifier(max_depth=5).cost_complexity_pruning_path(X_train, y_train)
    display(pd.DataFrame(path))

    _, ax = plt.subplots(figsize=(10, 4))

    # 最終行は1ノードだけの決定木なので出力は無駄
    ax.plot(path.ccp_alphas[:-1], path.impurities[:-1], marker="o", drawstyle="steps-post")
    ax.set_xlabel("effective alpha")
    ax.set_ylabel("total impurity of leaves")
    ax.set_title("Total Impurity vs effective alpha for training set")
    plt.show()

    return path

path = get_ccp_alphas(X_train, y_train)

結果は以下の値です。ccp_alphasは正則化パラメータのようなもので、値が大きいほど決定木がシンプルになります。ここでは最終行17はルートノードのみの状態です。
image.png
image.png

枝刈りの詳細は、以前「木の剪定アルゴリズム」としてはじめてのパターン認識」で学習しました(お勧め)。

他にも以下のページで復習しました。

4. 訓練

ccp_alphasごとに訓練します。

def train_with_alphas(X_train, y_train, ccp_alphas):
    clfs = []
    for i, ccp_alpha in enumerate(ccp_alphas):
        clf = DecisionTreeClassifier(max_depth=5, ccp_alpha=ccp_alpha)
        clf.fit(X_train, y_train)
        clfs.append(clf)
        print(f'Finished: {i+1}/{len(ccp_alphas)}')
    # ノード数は枝刈りを最後までやった結果なので必ず1
    print(f"最終決定木のノード数: {clfs[-1].tree_.node_count} with ccp_alpha: {ccp_alphas[-1]}")
    return clfs

clfs = train_with_alphas(X_train, y_train, path.ccp_alphas)

5. スコア算出

訓練したモデルごとに訓練および評価データに対するスコアを算出します。

train_scores = [clf.score(X_train, y_train) for clf in clfs[:-1]]
test_scores = [clf.score(X_test, y_test) for clf in clfs[:-1]]

6. 枝刈り単位の情報参照

枝刈りした状態ごとの以下の情報をグラフで右側に出力。

  1. 不純度(「3. 枝刈り準備」と同じグラフ)
  2. ノード数
  3. 木の深さ
  4. 訓練・評価スコア
def output_prune_result(path, clfs, train_scores, test_scores):
    node_counts = [clf.tree_.node_count for clf in clfs]
    depth = [clf.tree_.max_depth for clf in clfs]
    fig=plt.figure(figsize=(10,12))
    
    ax_11=fig.add_subplot(121)
    ax_11.axis('tight')
    ax_11.axis('off')
    tab = ax_11.table(cellText=np.round(pd.DataFrame(path).values, decimals=5),
                loc='upper left',
                colLabels=dir(path),
                rowLabels=np.arange(len(path.ccp_alphas)),
                colColours =["#EEEEEE"] * 2,
                rowColours =["#EEEEEE"] * len(path.ccp_alphas))
    tab.auto_set_font_size(False)
    tab.set_fontsize(15)
    tab.scale(1,2)

    ax_21=fig.add_subplot(422)
    ax_21.plot(path.ccp_alphas[:-1], path.impurities[:-1], marker="o", drawstyle="steps-post")
    ax_21.set_xlabel("effective alpha")
    ax_21.set_ylabel("total impurity of leaves")
    ax_21.set_title("Total Impurity vs effective alpha for training set")

    ax_22=fig.add_subplot(424)
    ax_22.plot(path.ccp_alphas[:-1], node_counts, marker="o", drawstyle="steps-post")
    ax_22.set_xlabel("alpha")
    ax_22.set_ylabel("number of nodes")
    ax_22.set_title("Number of nodes vs alpha")

    ax_23=fig.add_subplot(426)
    ax_23.plot(path.ccp_alphas[:-1], depth, marker="o", drawstyle="steps-post")
    ax_23.set_xlabel("alpha")
    ax_23.set_ylabel("depth of tree")
    ax_23.set_title("Depth vs alpha")

    ax_24=fig.add_subplot(428)
    ax_24.set_xlabel("alpha")
    ax_24.set_ylabel("accuracy")
    ax_24.set_title("Accuracy vs alpha for training and testing sets")
    ax_24.plot(path.ccp_alphas[:-1], train_scores, marker="o", label="train", drawstyle="steps-post")
    ax_24.plot(path.ccp_alphas[:-1], test_scores, marker="o", label="test", drawstyle="steps-post")
    ax_24.legend()    
    
    fig.tight_layout()
    plt.show()

output_prune_result(path, clfs[:-1], train_scores, test_scores)

一番下のグラフ(訓練・評価スコア)で11番目が最も評価スコアが高いので、以後は11番目の訓練モデルを使っていきます(Pythonのインデックスで0から数えて11で、1から数えたら12番目)。この後、11番目の訓練モデルを使っていくのですが、目で選ぶ以外に良い方法ないのでしょうか。
image.png

7. 分類評価指標表示

分類評価指標および特徴量重要性を表示します。

def output_graphs(clf, X_test, y_test):

    y_pred_proba = clf.predict_proba(X_test)
    y_pred = clf.predict(X_test)
    print(classification_report(y_test, y_pred))
    
    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12, 8), constrained_layout=True)
    fig.subplots_adjust(wspace=0.5, hspace=0.5)

    # Confusion Matrix 出力
    ConfusionMatrixDisplay.from_predictions(y_test, y_pred, ax=axes[0, 0])

    # Feature Importance 出力
    importances = pd.DataFrame({'Importance':clf.feature_importances_}, index=FEATURES)
    importances.sort_values('Importance', ascending=False).head(10).sort_values('Importance', ascending=True).plot.barh(ax=axes[0, 1], grid=True)

    # ROC曲線出力
    RocCurveDisplay.from_predictions(y_test, y_pred_proba[:,1], ax=axes[1, 0])
    axes[1, 0].set_title('ROC(Receiver Operating Characteristic) Curve')

    # 適合率-再現率グラフ出力
    PrecisionRecallDisplay.from_predictions(y_test, y_pred_proba[:,1], ax=axes[1, 1])

    plt.show()


output_graphs(clfs[11], X_test, y_test)

classification_report関数の結果。

classification_report
              precision    recall  f1-score   support

           0       0.90      1.00      0.95        55
           1       1.00      0.87      0.93        45

    accuracy                           0.94       100
   macro avg       0.95      0.93      0.94       100
weighted avg       0.95      0.94      0.94       100

混合行列、特徴量重要性、ROC、PR曲線を出力しています。
image.png

8. 決定木描画

scikit-learnのplot_tree関数とdtreeviz関数を使って決定木の描画。

CLASS_NAMES = ['Class 0', 'Class 1']
def output_trees(clf, X_train, y_train):
    plt.figure(figsize=(18,7))
    plot_tree(clf, filled=True, feature_names=FEATURES, class_names=CLASS_NAMES, fontsize=9)
    plt.show()
    viz = dtreeviz(
        clf,
        X_train, 
        y_train,
        feature_names=FEATURES,
        class_names=CLASS_NAMES,
        target_name='y', 
        fontname='DejaVu Sans' #fontname='Hiragino Sans'
    ) 
    display(viz)

output_trees(clfs[11], X_train, y_train)

plot_tree(scikit-learn)

シンプルでわかりやすい決定木です。赤がクラス0で青がクラス1に分類されたノードです。色が濃いほど確信度が高いです。

  1. 条件分岐: Trueの場合は左に分岐
  2. 不純度: ノードの不純度。今回はgini係数。
  3. サンプル数: ノートのサンプル数
  4. クラスごとサンプル数: 配列でクラス単位のサンプル数表示(重み付け学習をするとサンプル数ではなくなるので要注意)
  5. 分類: 分類されたクラス

image.png

dtreeviz

dtreevizという決定木描画パッケージは、ヒストグラムや円グラフで表現してくれます。この方がわかりやすいことも多いですね。fontnameをDejaVu Sansにしたのは、google colab上で指定しないと警告が出たからです。
Macで日本語表示した時はパラメータfontnameをHiragino Sansにして、パッケージjapanize-matplotlibを使いました。

image.png

9. おまけ: 推論と決定木描画

特定のレコードに対する推論時の決定木描画をdtreevizでします。ついでに推論結果も出しました。

def check_prediction(clf, X_train, y_train):

    viz = dtreeviz(
        clf,
        X_train, 
        y_train,
        feature_names=FEATURES,
        class_names=CLASS_NAMES,
        fontname='DejaVu Sans',
        #orientation ='LR',  # left-right orientation
        X=X_train[0])  # need to give single observation for prediction
              
    display(viz)
    print(f'Predicted Probability is: {clf.predict_proba(X_train[0:1,])}')


check_prediction(clfs[11], X_train, y_train)

どんなパスを経て予測をしたかがわかります。
image.png

predict_proba
Predicted Probability is: [[0.9702381 0.0297619]]

おまけ

いくつか疑問に思ったことを調べました。

Probability

predict_proba関数で出てくるprobabilityは、どんな値かを調べました。リーフノードにおける訓練時のサンプル数比率です。つまり、リーフノードは全部クラス1だった場合には、100%となります。

正確にはサンプル数比率ではなく、valueの比率です。valueは重み付け学習するとサンプル数でなくなります。

image.png

重み付け学習

学習時にclass_weightを使うことで重み付け学習をすることができます。

  • {0:1, 1:10}: クラス0に対してクラス1を10倍にしたい場合の指定
  • balanced: 訓練データからクラス間のバランスを自動計算する場合

同じデータで{0:1, 1:10}と重み付け学習して決定木描画しました。valueの値がsamplesと異なっているのがわかります。例えば赤枠で囲った左下のリーフノードはsample数が19にも関わらず、valueが[14, 50]となっています。これはクラス1のサンプル数5を10倍して50とカウントしているためです。そして、結果としてこのリーフノードに来た場合にはクラス1と判断されます。
image.png

参考

DecisionTreeClassifierについてこちらの記事が非常に詳しいです。

11
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?