機械学習レシピ#1では分類器として決定木を利用。今回はその可視化とどう処理されるのかについて。
分類器にはいくつかの種類がある。
- ANN(artificial neural network)
- SVM(Support Vector Machine)
前回決定木を利用した理由は、分類器がなぜ決定をするのかが正確に理解できる(理解しやすい)モデルのひとつのため。
今回利用するデータセット -アイリス-
https://en.wikipedia.org/wiki/Iris_flower_data_set
アイリスは、ある花がどんな種類か花弁の長さや幅のような種々の測定値に基いて特定できる。
データセットには3種の異なる花(setona,versicolor,virginica)が含まれている。全部で150例(50例x3)を用意。
各例では4つの特徴量(萼と花弁の長さと幅)が使われている。
Sepal length | Sepal width | Petal length | Petal width | Species |
---|---|---|---|---|
5.1 | 3.5 | 1.4 | 0.2 | I. setosa |
4.9 | 3.0 | 1.4 | 0.2 | I. setosa |
7.0 | 3.2 | 4.7 | 1.4 | I. versicolor |
目標は
- このデータセットを使って
- 分類器を学習させ
- そしてその分類機を使って新しい花を与えれば、どんな種類の花か予測
- さらにTree情報の可視化を行うこと
1.データのインポート
irisのデータセットをインポート。データ(花の名前)とメタデータ(特徴量)両方を含んでいる。
>>> from sklearn.datasets import load_iris
>>> iris = load_iris()
>>> print iris.feature_names
['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
>>> print iris.target_names
['setosa' 'versicolor' 'virginica']
dta変数には特徴量が入っている。
>>> print iris.data[0]
[ 5.1 3.5 1.4 0.2]
target変数にはラベルが入っている。
>>> print iris.target[0]
0 ※0なのでsetosaを表す
サンプルデータのいくつかを後ほど分類器がどれくらいの精度かテストするためにとっておく。このデータをテストデータと呼ぶ。プログラミングと同様にテストは非常に大事。
ただ、今回は3つの例(各種の花の1例ずつ)のみテストデータとして除外しておく。
学習用に大部分のデータを使用する。
# training data
train_target = np.delete(iris.target,test_idx)
train_data = np.delete(iris.data, test_idx, axis=0)
# testing data
test_target = iris.target[test_idx]
test_data = iris.data[test_idx]
2. 分類器を学習させる
決定木分類機を作って学習データについて学習させる。
from sklearn import tree
clf = tree.DecisionTreeClassifier()
clf.fit(train_data, train_target)
3. 新しい花の予測をする
分類器が予測するもの
>>> print test_target
[0 1 2]
分類器を使って予測したテストデータの結果は上記と同じになっている。
>>>print clf.predict(test_data)
[0 1 2]
4. 木を可視化して分類器の働き方をみる
pip install pydotplus
pip install graphviz
pip install --upgrade IPython
from sklearn.externals.six import StringIO
import pydotplus
dot_data = StringIO()
tree.export_graphviz(clf, out_file=dot_data,
feature_names=iris.feature_names,
class_names=iris.target_names,
filled=True, rounded=True,
impurity=False)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_pdf("iris.pdf")
open -a preview iris.pdf

各ノードは特徴量の1つについて、”はい”か”いいえ”を尋ねる。例えば一番上のノードでは花弁(petal)の幅0.8cm以下かどうか訪ねている。その答えがTrueなら左、Falseなら右へ進む。
>>> print iris.feature_names, iris.target_names
['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'] ['setosa' 'versicolor' 'virginica']
>>> print test_data[1], test_target[1]
[ 7. 3.2 4.7 1.4] 1
- pental widthは1.4で0.8より大きいのでFalse.右へ
- pental widthが1.75以下かどうか尋ねられてるおり、1.4で小さいのでTrue.左へ
- pental lengthが4.95以下か尋ねられており、4.7で小さいのでTrue.左へ
- pental widthが1.65以下か尋ねられており、1.4で小さいのでTrue.
- その結果、予測はversicolorになる。
