Help us understand the problem. What is going on with this article?

scikit-learnの決定木系モデルを視覚化する方法

More than 3 years have passed since last update.

概要

scikit-learn の決定木系のモデルを視覚化する方法についてのエントリーです。
最近良く使うので、備忘録&My チートシート代わりに書きます。
このエントリーでは、Windows版のPython3.5.2でサンプルコードを組んでいます。

環境の準備

 決定木の視覚化にあたって必要なコンポーネントは以下の通りです。
- scikit-learn
- Graphviz
- pydotplus

 Graphvizは、OSごとにインストール方法が異なります。
 scikit-learn は、デフォルトで入っていることが多いのではないでしょうか。
 一方 pydotplus は、pipでインストールする必要があるでしょう。

Graphviz のインストール(Windows7環境)

 Graphviz は、Graph Visualization Software のことです。
 DOT言語で記述された内容を画像にしてくれるライブラリです。
 詳細はこちらを読んでください。
 http://www.graphviz.org/Documentation.php

 ダウンロードページは以下の通りです。
 
- Window版
 http://www.graphviz.org/Download_windows.php
 
- RHEL,CentOS版
 http://www.graphviz.org/Download_linux_rhel.php
 
- ubunts版
 http://www.graphviz.org/Download_linux_ubuntu.php
 
- ソース版
 http://www.graphviz.org/Download_source.php
 
 
 Windows環境でインストールするので、以下のページからMSIファイルをダウンロードして実行します。
 http://www.graphviz.org/Download_windows.php

ダウンロードしたMSIファイルを実行すると、まず以下の画面が表示されます。
 「Next」をクリックして画面を進めます。
install_01.png

 エントリー執筆時点(2017/09/03)のバージョンは2.38です。
 ここでは全てのユーザーが使えるように「Everyone」を選んだ状態で次に進みます。
install_02.png

 

 インストールの準備ができたことを知らせるメッセージです。 
「Next」を押して進みます。
install_03.png

 コンポーネントのインストールが進むにつれて、インジケーターのゲージが満ちていきます。
 インジケータが完全に満たされたら「Next」をクリックします。
install_04.png

 無事にGraphvizのインストールが完了しました。
 「Close」をクリックして、ウィンドウを閉じます。
install_05.png

 続いて Pydotplus のインストールに移ります。

pydotplus

 
 先に述べたDOt言語を扱うためのpythonモジュールです。
 今回はWindows環境なので、Anaconda Prompt で作業を進めます。

launch_anaconda_prompt.png

Anaconda Prompt起動したら、"pip install pydotplus"というコマンドを実行します。

(C:\Program Files\Anaconda3) C:\Users\usr********>pip install pydotplus
Collecting pydotplus
  Downloading pydotplus-2.0.2.tar.gz (278kB)
    100% |################################| 286kB 860kB/s
Requirement already satisfied: pyparsing>=2.0.1 in c:\program files\anaconda3\li
b\site-packages (from pydotplus)
Building wheels for collected packages: pydotplus
  Running setup.py bdist_wheel for pydotplus ... done
  Stored in directory: C:\Users\usr********\AppData\Local\pip\Cache\wheels\43\31\
48\e1d60511537b50a8ec28b130566d2fbbe4ac302b0def4baa48
Successfully built pydotplus
Installing collected packages: pydotplus
Successfully installed pydotplus-2.0.2

 成功すると上記のような出力がされます。
 エラーが発生して中断されなければ、pydotplus のインストール作業は以上で完了となります。

環境変数の設定

 
 次に、pydotplus に Graphviz のインストールパスを認識させるべく、環境変数を編集します。
 まず、graphvizがインストールされた場所にある "bin" ディレクトリの位置を確認します。
スタートメニューのリストにある「gvedit.exe」のプロパティを見ると確認できます。
install_path.png

 このパス("C:\Program Files (x86)\Graphviz2.38/bin")を、環境変数pathに追加します。
setup_env.png

 path を変更したら、PythonのIDE(PyCharmなど)を再起動しておきます。

決定木系モデルの視覚化のサンプルコード

 
 おなじみの irisデータセットを使って RandomForestモデルを作り、そのモデルの中から決定木モデルを一つ取り出して視覚化(=png画像で出力)してみました。
 そのサンプルコードは以下の通りです。
 内部処理を把握するためのデバッグ用print文をそのまま残してあります。
 

u"""
    決定木系モデルを視覚化する。
    Graphviz を用いて、決定木のモデルを視覚化する。
    決定木だけでなく、ランダムフォレストなど木構造のモデルに適用できる。

"""

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.cross_validation import train_test_split
from sklearn.model_selection import cross_val_score

# モデルの木構造の視覚化に必要なパッケージ
from sklearn import tree
import pydotplus as pdp

import pandas as pd
import numpy as np

iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)

print(df.head(5))
print(iris.target)
print(iris.target_names)
df['species'] = pd.Categorical.from_codes(iris.target, iris.target_names)
print(df.head(5))

# 学習データとテストデータを分ける

features = df.columns[:4]
label = df["species"]
print(features)
print(label)
print(df[features].head(5))
df_train, df_test, label_train, label_test = train_test_split(df[features], label)

clf = RandomForestClassifier(n_estimators=150)
clf.fit(df_train, label_train)
print("========================================================")
print("予測の精度")
print(clf.score(df_test, label_test))

# 試しに木の一つを視覚化する
estimators = clf.estimators_
file_name = "./tree_visualization.png"
dot_data = tree.export_graphviz(estimators[0], # 決定木オブジェクトを一つ指定する
                                out_file=None, # ファイルは介さずにGraphvizにdot言語データを渡すのでNone
                                filled=True, # Trueにすると、分岐の際にどちらのノードに多く分類されたのか色で示してくれる
                                rounded=True, # Trueにすると、ノードの角を丸く描画する。
                                feature_names=features, # これを指定しないとチャート上で特徴量の名前が表示されない
                                class_names=iris.target_names, # これを指定しないとチャート上で分類名が表示されない
                                special_characters=True # 特殊文字を扱えるようにする
                                )
graph = pdp.graph_from_dot_data(dot_data)
graph.write_png(file_name)

 それでは各部を見ていきます。
 iris データセットを読み込んで、学習データを用意しているのは以下の部分です。
 特徴量の名前ば iris.feature_names にセットされています。
 目的変数(=アヤメの種類)は、iris.target にセットされています。
 ただし、iris.targetは数字であり、人が読むのには不親切な状態です。
 そこで iris.target_names にある種類名の表記を使い、人間が読める(=human-redable)な目的変数をdf['species'] にセットしています。

iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)

print(df.head(5))
print(iris.target)
print(iris.target_names)
df['species'] = pd.Categorical.from_codes(iris.target, iris.target_names)
print(df.head(5))

 続いて、RandomForestモデルを作る部分のコードです。
 学習データを格納した df は、目的変数も含んでいます。
 特徴量の部分と目的変数を分けて、モデルにインプットする必要があります。
 なので、特徴量部分は features に、目的変数は label にセットします。
 そして、train_test_split でモデル学習用とテスト用のデータに分割します。
 clf には、RandomForestオブジェクトがセットされます。使用する決定木の数は150個としています。(引数:n_estimator=150)
 あとは 学習用データを指定して、fit()メソッドでモデルに学習させます。

features = df.columns[:4]
label = df["species"]
print(features)
print(label)
print(df[features].head(5))
df_train, df_test, label_train, label_test = train_test_split(df[features], label)

clf = RandomForestClassifier(n_estimators=150)
clf.fit(df_train, label_train)

 そしていよいよ視覚化です。
 RandomForestオブジェクトは estimators_ というプロパティを持っています。
 estimators_ は、決定木オブジェクトのリストです。
 ここではサンプルとして一番目の決定木オブジェクト(estimators[0])を視覚化します。
 png画像ファイル "tree_visualization.png" として出力します。
 tree.export_graphviz() が視覚化処理をしています。
 引数の説明はコードのコメント中に記述しました。
引数をちゃんと指定しないと、特徴量名、分類名ともに表示されないので注意が必要です。

# 試しに木の一つを視覚化する
estimators = clf.estimators_
file_name = "./tree_visualization.png"
dot_data = tree.export_graphviz(estimators[0], # 決定木オブジェクトを一つ指定する
                                out_file=None, # ファイルは介さずにGraphvizにdot言語データを渡すのでNone
                                filled=True, # Trueにすると、分岐の際にどちらのノードに多く分類されたのか色で示してくれる
                                rounded=True, # Trueにすると、ノードの角を丸く描画する。
                                feature_names=features, # これを指定しないとチャート上で特徴量の名前が表示されない
                                class_names=iris.target_names, # これを指定しないとチャート上で分類名が表示されない
                                special_characters=True # 特殊文字を扱えるようにする
                                )
graph = pdp.graph_from_dot_data(dot_data)
graph.write_png(file_name)

 すると以下のように決定木が視覚化されたものが、png画像として得られます。

 tree_visualization.png

takahashi_yukou
データサイエンティスト&機械学習エンジニアとして、機械学習、自然言語処理の様々なプロジェクトに携わっています。最近は、新型コロナの統計データの分析を始めました。
https://yuukou-exp.plus/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away