Pythonで実装された機械学習ライブラリのscikit-learnは様々なアルゴリズムを簡単に試せることからしばしば利用されています。花形と言えばTensorFlowやPyTorchですが、お堅い現場ではなかなか使えません。。。そんなscikit-learnで教師有り学習の代表的な手法「決定木」を学習後に描画時に便利な関数がVersion0.21.xから実装されたので従来のGraphVizを用いる方法と比較しつつ試してみました。
従来の可視化方法: GraphVizを利用
従来はGraphVizという別のライブラリをインストールして、利用していました。結構手間が掛かります。。。
brew install graphviz
pip install graphviz
sudo apt install -y graphviz
pip install graphviz
import graphviz
from sklearn import tree
from sklearn.datasets import load_iris
iris = load_iris()
clf = DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)
graph = graphviz.Source(tree.export_graphviz(clf, class_names=iris.feature_names, filled=True))
graph
実行結果
実行結果はgraph.render('decision_tree')
を実行するとPDFとして保存できます。
tree.plot_treeを利用
tree.plot_tree
を用いてGraphVizを利用して描画した物と同様の図を描画してみます。scikit-learnのtreeモジュールに格納されている為、追加のインストールは不要です。(filled
オプションはデフォルトではFalseですが、Trueにすると彩色されます)
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
iris = load_iris()
clf = DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)
iris = load_iris()
plt.figure(figsize=(15, 10))
plot_tree(clf, feature_names=iris.feature_names, filled=True)
plt.show()
実行結果
GraphVizを用いた方法と同じ図を出力出来ました。Jupyter Notebook上で実行すれば、描画結果をそのまま右クリックして画像として保存も出来ます。
2020/11/27追記: クラス名も決定木に表示
class_name
オプションを追加すると最終的に分類されたクラス名も表示出来ます。
plt.figure(figsize=(15, 10))
plot_tree(clf, feature_names=iris.feature_names, class_names=iris.target_names, filled=True)
plt.show()
まとめ
scikit-learnのtree.plot_treeと従来のGraphVizを用いる方法を決定木の可視化に対して行い、tree.plot_treeが(従来の方法より)簡単かつ便利だと実感しました。今後積極的に活用していきたいと思います。