概要
この記事では、私がつくったライブラリdtreeplt
について解説します。
決定木可視化界隈に革命を起こすことが目的です。
↑こんなのをmatplotlibとnumpyのみで描画します。すごいぞ。
追記(2020/01/05)
dtreepltを公開したすぐ後に出たsklearn0.21にて、 sklearn.tree.plot_tree
が実装されてました。。
dtreepltと同様に、Graphvizを使わずにmatplotlibで決定木を描画できるようになりました。
dtreepltに独自性がないわけではないのですが、サクッと決定木みたいなーってときにはplot_treeを使ったほうが楽でしょう。
サンプルコードと出力結果は下記のとおりです。(githubのREADMEにも載せてます)
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
%matplotlib inline
iris = load_iris()
model = DecisionTreeClassifier()
model.fit(iris.data, iris.target)
fig = plt.figure(figsize=(15, 8))
ax = fig.add_subplot()
plot_tree(model, feature_names=iris.feature_names, ax=ax, class_names=iris.target_names, filled=True);
決定木可視化における課題
データ分析各位、決定木は好きですか?わたしは好きです。
データ分析初学者のファーストステップとして、IrisとかTitanicをscikit-learnのDecisionTreeClassifierで分類してexport_graphvizで可視化することが多いですよね。
初学者に限らず決定木による可視化は解釈性が高く、よく使われることと思います。
ただ、可視化の際に使われるGraphviz、インストールがめちゃめちゃ難易度高いことに定評があります。
Windows, Mac, Linuxで変わってくるのも辛いポイントですね。
「データ分析やるぞ!」と意気込みAnacondaを入れたPython初心者が、Graphvizのインストールで躓いてしまうのは悲しいです。
そこで、Graphvizを使わずにnumpyとmatplotlib(と、もちろんscikit-learn)のみで決定木を可視化するライブラリdtreeplt
をつくりました。
Githubリポジトリはこちら
インストール
これだけです。最新のAnaconda環境ならなんの問題もないと思います。
pip install dtreeplt
※2019/06/16追記
ノードが重なって表示される致命的なバグを修正したので、gitから最新版をinstallするのが一番確実です。
pip install git+https://github.com/nekoumei/dtreeplt.git
upgradeの場合は下記のようにしてください
pip install git+https://github.com/nekoumei/dtreeplt.git -U
PyPIも一応更新済なんですが、今後もgitが常に最新となるのでこちらを推奨します
numpy, matplotlib, scikit-learnが入っていない場合は、事前に↓みたいにして入れてください。
pip install numpy matplotlib scikit-learn
また、これらのバージョンが古くてうまくいかない場合は↓みたいにupgradeしてください。
# numpyのバージョンが古かった場合
pip install numpy --upgrade
Pythonのバージョンは3.6.X
以上を想定しています。
以上で準備は終わりです。
Quick Start
Jupyter Notebookで、3行で確認できます。
本当は、dtreeplt
インスタンス作成時にパラメータで学習済決定木モデルを渡す必要がありますが、モデルが渡されなかった場合、自動的にIrisデータセットからモデルを作成して可視化します。
ほんとうのつかいかた
このライブラリは、アヤメの分類器可視化専用ではないので任意の決定木モデルを渡して可視化することができます。下記のとおりパラメータを渡してあげてください。
(結局アヤメの分類をしている)
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from dtreeplt import dtreeplt
iris = load_iris()
model = DecisionTreeClassifier()
model.fit(iris.data, iris.target)
dtree = dtreeplt(
model=model,
feature_names=iris.feature_names,
target_names=iris.target_names
)
fig = dtree.view()
#if you want save figure, use savefig method in returned figure object.
#fig.savefig('output.png')
dtree.view()はmatplotlibのfigureオブジェクトをreturnするので、savefigしたいときはそれを使えばOKです。
また、fontはmatplotlibのデフォルトに依存するので日本語が出ない!とかがあったら matplotlib 日本語
とかでググれば解決すると思います。
Future Works
そもそも、Graphvizを使わずに可視化したい!と思ったきっかけがdtreevizというライブラリを使おうとしたときでした。
dtreevizは決定木可視化をめっちゃ良いかんじにしてくれるライブラリです。語彙力がないですね。
このissueと似たようなエラーを踏んでしまって、自分のMac環境では未だ使えていません。(Linux環境では使えました)
そこで、dtreevizをGraphvizなしでやりたい!という思いからdtreepltをつくりはじめました。
なので、いずれはdtreevizをmatplotlibで表現することが目標です。のんびり作っていくので気長にお待ちください。
余談
その1
dtreepltのキモは、決定木の情報をいかにして取ってくるかです。実は決定木オブジェクトにはすべての情報が入っているので、それをとってくるだけです。
このあたりの知見は1年くらいまえにやった3-Dインタラクティブ決定木のときに雑ですが解説しています。
興味があるひとはそちらもご覧いただければ。
その2
個人的には、Graphviz自体は大好きです。パス分析も容易にできるので素晴らしいツールだと思います。
ただ、決定木を可視化するためだけにPython外のツールを導入するのはちょっとめんどくさいんですよね。
決定木のように構造が単純であればmatplotlibだけでもじゅうぶん描画できます。