決定木は人間にとって判断基準がわかりやすい判別・回帰の手法です。
そのため判断基準を可視化したくなることが多いのですが、dtreeviz というとてもわかりやすい可視化ライブラリが公開されたそうなので、決定木系の RandomForest で簡単に試してみます。
決定木について知りたい方は下記の記事が参考になります。
https://qiita.com/3000manJPY/items/ef7495960f472ec14377
実行環境
Google Colaboratory を使用します。
まず決定木で試す
下記のページを参考にさせていただいており、そのままですが決定木を実行してみます。
https://qiita.com/calderarie/items/e4321bff95ac3042601b
まず dtreeviz をインストールします。
! pip install dtreeviz
scikit-learn 付属の iris のデータセットを使用します。3種類の花ごとの花弁の長さなどのデータとなっており、特徴をみて種類を分類できるかを試すことができます。
from sklearn.datasets import load_iris
from sklearn import tree
from dtreeviz.trees import dtreeviz
iris = load_iris()
clf = tree.DecisionTreeClassifier(max_depth=2)
clf.fit(iris.data, iris.target)
これで決定木を得ることができました。
dtreeviz で可視化してみます。
viz = dtreeviz(
clf,
iris.data,
iris.target,
target_name='variety',
feature_names=iris.feature_names,
class_names=[str(i) for i in iris.target_names],
)
viz
花弁の幅を条件にして判別が行われ、setosa, versicolor と分類が定められる決定木の内容が可視化されます。
RandomForest の場合
RandomForest は複数の決定木を使用して、学習データ以外のデータに対する精度を向上させる手法です。
以下が参考になります。
といっても実行するだけなら簡単です。
from sklearn.ensemble import RandomForestClassifier
clf = RandomForestClassifier(n_estimators=150)
clf.fit(iris.data, iris.target)
複数の決定木がありますが、今回は先頭の決定木を可視化します。
estimators = clf.estimators_
viz = dtreeviz(
estimators[0],
iris.data,
iris.target,
target_name='variety',
feature_names=iris.feature_names,
class_names=[str(i) for i in iris.target_names],
)
viz
出力したい決定木を決めれば同様に出力できます。