LoginSignup
3
1

ツリーの4つの可視化方法

Last updated at Posted at 2024-05-28

image.png

ツリー構造の4つの可視化方法

ランダムフォレストやXGBoost、決定木分析をした時にモデルのツリー構造を確認します。決定木の大きさやデータによって描画の仕方に使い分けができるので、それぞれまとめました。

💡この記事で紹介すること

  1. tree.export_textでテキスト描画
  2. tree.plot_treeを使う
  3. graphvizを使う
  4. dtreevizでヒストグラムや円グラフも追加

ツリーの可視化

KaggleのSpaceship Titanicを題材にします。

このコンペは乗客が異次元に転送されたかどうかを推論する2値分類問題です。
説明変数は13個で、6個がnumerical、7個がcategoricalです。
今回はnumerical変数のみを説明変数として扱います。

前処理はこちら
[in]
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier 
from sklearn.model_selection import train_test_split

df = pd.read_csv('/kaggle/input/spaceship-titanic/train.csv')

target = 'Transported'
X = df.drop([target], axis=1)
X = X[X.loc[:,X.dtypes=='object'].columns]
X = X.fillna(X.mean())
y = df[target].astype(int)

決定木分析でつくったモデルを使います。

[in]
feature_names = list(X.columns)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
clf = DecisionTreeClassifier(max_depth=4,random_state=1234)
model= clf.fit(X_train, y_train)

tree.export_textでテキスト描画

モデルをつくったsklearn.treeには3つの出力方法があります。
ドキュメントはこちら

1つ目は、テキストによる描画です。

こんな時に使う
さっとツリーを確認したいときやコマンドラインで出力したいときに適しています

[in]
from sklearn import tree
print(tree.export_text(model,feature_names=feature_names))

[out]
|--- RoomService <= 0.50
|   |--- Spa <= 2.50
|   |   |--- VRDeck <= 332.00
|   |   |   |--- Age <= 12.50
|   |   |   |   |--- class: 1
|   |   |   |--- Age >  12.50
|   |   |   |   |--- class: 1
|   |   |--- VRDeck >  332.00
|   |   |   |--- FoodCourt <= 3657.00
|   |   |   |   |--- class: 0
|   |   |   |--- FoodCourt >  3657.00
|   |   |   |   |--- class: 1
|   |--- Spa >  2.50
|   |   |--- FoodCourt <= 2507.50
|   |   |   |--- Spa <= 314.57
|   |   |   |   |--- class: 0
|   |   |   |--- Spa >  314.57
|   |   |   |   |--- class: 0
|   |   |--- FoodCourt >  2507.50
|   |   |   |--- Spa <= 2523.50
|   |   |   |   |--- class: 1
|   |   |   |--- Spa >  2523.50
|   |   |   |   |--- class: 0
|--- RoomService >  0.50
|   |--- RoomService <= 365.50
|   |   |--- Spa <= 422.50
(省略)

階層が深いツリーを扱う場合はmax_depthを設定して大項目だけ見やすくするのも良いです。

tree.export_text(model,feature_names=feature_names, max_depth=3)

tree.plot_treeを使う

plot_treeは、export_text同様sklearn.treeにあります。
こちらは画像としてプロットを生成します。

こんな時に使う
ツリーの全体像を図でさっと確認したいときに便利です。ツリーが複雑ではない場合が望ましいです

[in]
fig = plt.figure(figsize=(32,20))
_ = tree.plot_tree(model,feature_names=feature_names, fontsize=10,filled=True)

image.png

画面のサイズに収まるように表示するので、サイズ調整が必要になることもがあります。
image.png

figureのメソッドで画像を保存します。

fig.savefig('tree.png')

graphvizを使う

こちらもsklearn.treeにあります。graphvizはdot言語というグラフ構造を記述する言語によって書かれています。plot_treeとは描画の挙動が異なる点に注意です。

こんな時に使う
ツリーが大きいときやパラメーターを細かく見たいときに適しています

[in]
import graphviz
dot_data = tree.export_graphviz(model, out_file=None, 
                                feature_names=feature_names,  
                                class_names=['True','False'],
                                filled=True)
graph = graphviz.Source(dot_data, format="png") 
graph

tree (1).png

plot_treeと異なりオーバーフローが適用されます。大規模なツリーであったり、表示項目が多くても、表示サイズを気にしなくてよいです。

image.png

保存はrenderを使います。事前にフォーマットとしてpngを指定しているので拡張子の記述は不要です。

graph.render('tree')

importが上手くいかない場合はドキュメントを確認してください。

dtreevizでヒストグラムや円グラフも追加

最後はdtreevizです。dtreevizはツリー構造を説明するためのpythonライブラリで、scikit-learn、XGBoost、Spark MLlib、LightGBM、Tensorflowをサポートしています。

こんな時に使う
入力したデータをどのように分岐させたかを確認したいとき

[in]
!pip install dtreeviz
import dtreeviz as dt

viz = dt.model(model,X_train=X_test, y_train=y_test,
                target_name="Transported",
                feature_names=list(feature_names))

viz.view()

image.png

データのヒストグラムに対してどこに分岐の基準点を置いたかをヒストグラムで表し、どの程度の割合でそれぞれのクラスが混ざっているかを円グラフで表しています。

image.png

また、保存はsvg形式です。

viz.view().save('tree4.svg')

まとめ

dtreevizは性能を確認したり、説明するのにとても便利です。個人的には、実際の解析では方向性が見えるまではさらっとテキストで出力することも多いです。

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1