ツリー構造の4つの可視化方法
ランダムフォレストやXGBoost、決定木分析をした時にモデルのツリー構造を確認します。決定木の大きさやデータによって描画の仕方に使い分けができるので、それぞれまとめました。
💡この記事で紹介すること
- tree.export_textでテキスト描画
- tree.plot_treeを使う
- graphvizを使う
- dtreevizでヒストグラムや円グラフも追加
ツリーの可視化
KaggleのSpaceship Titanicを題材にします。
このコンペは乗客が異次元に転送されたかどうかを推論する2値分類問題です。
説明変数は13個で、6個がnumerical、7個がcategoricalです。
今回はnumerical変数のみを説明変数として扱います。
前処理はこちら
[in]
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
df = pd.read_csv('/kaggle/input/spaceship-titanic/train.csv')
target = 'Transported'
X = df.drop([target], axis=1)
X = X[X.loc[:,X.dtypes=='object'].columns]
X = X.fillna(X.mean())
y = df[target].astype(int)
決定木分析でつくったモデルを使います。
[in]
feature_names = list(X.columns)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
clf = DecisionTreeClassifier(max_depth=4,random_state=1234)
model= clf.fit(X_train, y_train)
tree.export_textでテキスト描画
モデルをつくったsklearn.treeには3つの出力方法があります。
ドキュメントはこちら。
1つ目は、テキストによる描画です。
こんな時に使う
さっとツリーを確認したいときやコマンドラインで出力したいときに適しています
[in]
from sklearn import tree
print(tree.export_text(model,feature_names=feature_names))
[out]
|--- RoomService <= 0.50
| |--- Spa <= 2.50
| | |--- VRDeck <= 332.00
| | | |--- Age <= 12.50
| | | | |--- class: 1
| | | |--- Age > 12.50
| | | | |--- class: 1
| | |--- VRDeck > 332.00
| | | |--- FoodCourt <= 3657.00
| | | | |--- class: 0
| | | |--- FoodCourt > 3657.00
| | | | |--- class: 1
| |--- Spa > 2.50
| | |--- FoodCourt <= 2507.50
| | | |--- Spa <= 314.57
| | | | |--- class: 0
| | | |--- Spa > 314.57
| | | | |--- class: 0
| | |--- FoodCourt > 2507.50
| | | |--- Spa <= 2523.50
| | | | |--- class: 1
| | | |--- Spa > 2523.50
| | | | |--- class: 0
|--- RoomService > 0.50
| |--- RoomService <= 365.50
| | |--- Spa <= 422.50
(省略)
階層が深いツリーを扱う場合はmax_depth
を設定して大項目だけ見やすくするのも良いです。
tree.export_text(model,feature_names=feature_names, max_depth=3)
tree.plot_treeを使う
plot_treeは、export_text同様sklearn.treeにあります。
こちらは画像としてプロットを生成します。
こんな時に使う
ツリーの全体像を図でさっと確認したいときに便利です。ツリーが複雑ではない場合が望ましいです
[in]
fig = plt.figure(figsize=(32,20))
_ = tree.plot_tree(model,feature_names=feature_names, fontsize=10,filled=True)
画面のサイズに収まるように表示するので、サイズ調整が必要になることもがあります。
figure
のメソッドで画像を保存します。
fig.savefig('tree.png')
graphvizを使う
こちらもsklearn.treeにあります。graphvizはdot言語というグラフ構造を記述する言語によって書かれています。plot_treeとは描画の挙動が異なる点に注意です。
こんな時に使う
ツリーが大きいときやパラメーターを細かく見たいときに適しています
[in]
import graphviz
dot_data = tree.export_graphviz(model, out_file=None,
feature_names=feature_names,
class_names=['True','False'],
filled=True)
graph = graphviz.Source(dot_data, format="png")
graph
plot_treeと異なりオーバーフローが適用されます。大規模なツリーであったり、表示項目が多くても、表示サイズを気にしなくてよいです。
保存はrender
を使います。事前にフォーマットとしてpngを指定しているので拡張子の記述は不要です。
graph.render('tree')
importが上手くいかない場合はドキュメントを確認してください。
dtreevizでヒストグラムや円グラフも追加
最後はdtreevizです。dtreevizはツリー構造を説明するためのpythonライブラリで、scikit-learn、XGBoost、Spark MLlib、LightGBM、Tensorflowをサポートしています。
こんな時に使う
入力したデータをどのように分岐させたかを確認したいとき
[in]
!pip install dtreeviz
import dtreeviz as dt
viz = dt.model(model,X_train=X_test, y_train=y_test,
target_name="Transported",
feature_names=list(feature_names))
viz.view()
データのヒストグラムに対してどこに分岐の基準点を置いたかをヒストグラムで表し、どの程度の割合でそれぞれのクラスが混ざっているかを円グラフで表しています。
また、保存はsvg形式です。
viz.view().save('tree4.svg')
まとめ
dtreevizは性能を確認したり、説明するのにとても便利です。個人的には、実際の解析では方向性が見えるまではさらっとテキストで出力することも多いです。