69
65

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

scikit-learnの決定木系モデルを視覚化する方法

Posted at

概要

scikit-learn の決定木系のモデルを視覚化する方法についてのエントリーです。
最近良く使うので、備忘録&My チートシート代わりに書きます。
このエントリーでは、Windows版のPython3.5.2でサンプルコードを組んでいます。

環境の準備

 決定木の視覚化にあたって必要なコンポーネントは以下の通りです。

  • scikit-learn
  • Graphviz
  • pydotplus

 Graphvizは、OSごとにインストール方法が異なります。
 scikit-learn は、デフォルトで入っていることが多いのではないでしょうか。
 一方 pydotplus は、pipでインストールする必要があるでしょう。

Graphviz のインストール(Windows7環境)

 Graphviz は、Graph Visualization Software のことです。
 DOT言語で記述された内容を画像にしてくれるライブラリです。
 詳細はこちらを読んでください。
 http://www.graphviz.org/Documentation.php

 ダウンロードページは以下の通りです。
 

ダウンロードしたMSIファイルを実行すると、まず以下の画面が表示されます。
 「Next」をクリックして画面を進めます。
install_01.png

 エントリー執筆時点(2017/09/03)のバージョンは2.38です。
 ここでは全てのユーザーが使えるように「Everyone」を選んだ状態で次に進みます。
install_02.png

 

 インストールの準備ができたことを知らせるメッセージです。 
「Next」を押して進みます。
install_03.png

 コンポーネントのインストールが進むにつれて、インジケーターのゲージが満ちていきます。
 インジケータが完全に満たされたら「Next」をクリックします。
install_04.png

 無事にGraphvizのインストールが完了しました。
 「Close」をクリックして、ウィンドウを閉じます。
install_05.png

 続いて Pydotplus のインストールに移ります。

pydotplus

 
 先に述べたDOt言語を扱うためのpythonモジュールです。
 今回はWindows環境なので、Anaconda Prompt で作業を進めます。

launch_anaconda_prompt.png

Anaconda Prompt起動したら、"pip install pydotplus"というコマンドを実行します。

(C:\Program Files\Anaconda3) C:\Users\usr********>pip install pydotplus
Collecting pydotplus
  Downloading pydotplus-2.0.2.tar.gz (278kB)
    100% |################################| 286kB 860kB/s
Requirement already satisfied: pyparsing>=2.0.1 in c:\program files\anaconda3\li
b\site-packages (from pydotplus)
Building wheels for collected packages: pydotplus
  Running setup.py bdist_wheel for pydotplus ... done
  Stored in directory: C:\Users\usr********\AppData\Local\pip\Cache\wheels\43\31\
48\e1d60511537b50a8ec28b130566d2fbbe4ac302b0def4baa48
Successfully built pydotplus
Installing collected packages: pydotplus
Successfully installed pydotplus-2.0.2

 成功すると上記のような出力がされます。
 エラーが発生して中断されなければ、pydotplus のインストール作業は以上で完了となります。

環境変数の設定

 
 次に、pydotplus に Graphviz のインストールパスを認識させるべく、環境変数を編集します。
 まず、graphvizがインストールされた場所にある "bin" ディレクトリの位置を確認します。
スタートメニューのリストにある「gvedit.exe」のプロパティを見ると確認できます。
install_path.png

 このパス("C:\Program Files (x86)\Graphviz2.38/bin")を、環境変数pathに追加します。
setup_env.png

 path を変更したら、PythonのIDE(PyCharmなど)を再起動しておきます。

決定木系モデルの視覚化のサンプルコード

 
 おなじみの irisデータセットを使って RandomForestモデルを作り、そのモデルの中から決定木モデルを一つ取り出して視覚化(=png画像で出力)してみました。
 そのサンプルコードは以下の通りです。
 内部処理を把握するためのデバッグ用print文をそのまま残してあります。
 

u"""
    決定木系モデルを視覚化する。
    Graphviz を用いて、決定木のモデルを視覚化する。
    決定木だけでなく、ランダムフォレストなど木構造のモデルに適用できる。

"""

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.cross_validation import train_test_split
from sklearn.model_selection import cross_val_score

# モデルの木構造の視覚化に必要なパッケージ
from sklearn import tree
import pydotplus as pdp

import pandas as pd
import numpy as np

iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)

print(df.head(5))
print(iris.target)
print(iris.target_names)
df['species'] = pd.Categorical.from_codes(iris.target, iris.target_names)
print(df.head(5))

# 学習データとテストデータを分ける

features = df.columns[:4]
label = df["species"]
print(features)
print(label)
print(df[features].head(5))
df_train, df_test, label_train, label_test = train_test_split(df[features], label)

clf = RandomForestClassifier(n_estimators=150)
clf.fit(df_train, label_train)
print("========================================================")
print("予測の精度")
print(clf.score(df_test, label_test))

# 試しに木の一つを視覚化する
estimators = clf.estimators_
file_name = "./tree_visualization.png"
dot_data = tree.export_graphviz(estimators[0], # 決定木オブジェクトを一つ指定する
                                out_file=None, # ファイルは介さずにGraphvizにdot言語データを渡すのでNone
                                filled=True, # Trueにすると、分岐の際にどちらのノードに多く分類されたのか色で示してくれる
                                rounded=True, # Trueにすると、ノードの角を丸く描画する。
                                feature_names=features, # これを指定しないとチャート上で特徴量の名前が表示されない
                                class_names=iris.target_names, # これを指定しないとチャート上で分類名が表示されない
                                special_characters=True # 特殊文字を扱えるようにする
                                )
graph = pdp.graph_from_dot_data(dot_data)
graph.write_png(file_name)

 それでは各部を見ていきます。
 iris データセットを読み込んで、学習データを用意しているのは以下の部分です。
 特徴量の名前ば iris.feature_names にセットされています。
 目的変数(=アヤメの種類)は、iris.target にセットされています。
 ただし、iris.targetは数字であり、人が読むのには不親切な状態です。
 そこで iris.target_names にある種類名の表記を使い、人間が読める(=human-redable)な目的変数をdf['species'] にセットしています。

iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)

print(df.head(5))
print(iris.target)
print(iris.target_names)
df['species'] = pd.Categorical.from_codes(iris.target, iris.target_names)
print(df.head(5))

 続いて、RandomForestモデルを作る部分のコードです。
 学習データを格納した df は、目的変数も含んでいます。
 特徴量の部分と目的変数を分けて、モデルにインプットする必要があります。
 なので、特徴量部分は features に、目的変数は label にセットします。
 そして、train_test_split でモデル学習用とテスト用のデータに分割します。
 clf には、RandomForestオブジェクトがセットされます。使用する決定木の数は150個としています。(引数:n_estimator=150)
 あとは 学習用データを指定して、fit()メソッドでモデルに学習させます。

features = df.columns[:4]
label = df["species"]
print(features)
print(label)
print(df[features].head(5))
df_train, df_test, label_train, label_test = train_test_split(df[features], label)

clf = RandomForestClassifier(n_estimators=150)
clf.fit(df_train, label_train)

 そしていよいよ視覚化です。
 RandomForestオブジェクトは estimators_ というプロパティを持っています。
 estimators_ は、決定木オブジェクトのリストです。
 ここではサンプルとして一番目の決定木オブジェクト(estimators[0])を視覚化します。
 png画像ファイル "tree_visualization.png" として出力します。
 tree.export_graphviz() が視覚化処理をしています。
 引数の説明はコードのコメント中に記述しました。
引数をちゃんと指定しないと、特徴量名、分類名ともに表示されないので注意が必要です。

# 試しに木の一つを視覚化する
estimators = clf.estimators_
file_name = "./tree_visualization.png"
dot_data = tree.export_graphviz(estimators[0], # 決定木オブジェクトを一つ指定する
                                out_file=None, # ファイルは介さずにGraphvizにdot言語データを渡すのでNone
                                filled=True, # Trueにすると、分岐の際にどちらのノードに多く分類されたのか色で示してくれる
                                rounded=True, # Trueにすると、ノードの角を丸く描画する。
                                feature_names=features, # これを指定しないとチャート上で特徴量の名前が表示されない
                                class_names=iris.target_names, # これを指定しないとチャート上で分類名が表示されない
                                special_characters=True # 特殊文字を扱えるようにする
                                )
graph = pdp.graph_from_dot_data(dot_data)
graph.write_png(file_name)

 すると以下のように決定木が視覚化されたものが、png画像として得られます。

 tree_visualization.png

69
65
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
69
65

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?