1
0

More than 3 years have passed since last update.

決定木(load_iris)

Last updated at Posted at 2020-11-24

■ はじめに

今回は、決定木の実装~プロットをまとめていきます。

【対象とする読者の方】
・決定木における、基礎のコードを学びたい方
・理論は詳しく分からないが、実装を見てイメージをつけたい方 など

1. モジュールの用意

最初に、必要なモジュールをインポートしておきます。


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import plot_tree


2. データの準備

load_iris データセットを使用します。


iris = load_iris()
X, y = iris.data[:, [0, 2]], iris.target

print(X.shape)
print(y.shape)

# (150, 2)
# (150,)

trainとtestデータに分割します。


X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.3, random_state = 123)

print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_train.shape)

# (105, 2)
# (105,)
# (45, 2)
# (45,)

決定木では個々の特徴量は独立に処理され、データの分割はスケールに依存しないため
正規化や標準化は不要となります。

3. データの可視化

モデリングする前に、データをプロットして見ておきます。


fig, ax = plt.subplots()

ax.scatter(X_train[y_train == 0, 0], X_train[y_train == 0, 1], 
           marker = 'o', label = 'Setosa')

ax.scatter(X_train[y_train == 1, 0], X_train[y_train == 1, 1],
           marker = 'x', label = 'Versicolor')

ax.scatter(X_train[y_train == 2, 0], X_train[y_train == 2, 1],
           marker = 'x', label = 'Varginica')

ax.set_xlabel('Sepal Length')
ax.set_ylabel('Petal Length')
ax.legend(loc = 'best')

plt.show()

image.png

4. モデルの作成

決定木のモデルを作成します。


tree = DecisionTreeClassifier(max_depth = 3)
tree.fit(X_train, y_train)

'''
DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',
                       max_depth=3, max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort='deprecated',
                       random_state=None, splitter='best')

'''

併せて、可視化もしておきます。


fig, ax = plt.subplots(figsize=(10, 10))
plot_tree(tree, feature_names=iris.feature_names, filled=True)
plt.show()

image.png

5. 予測値の出力

テストデータに対する予測を行います。


y_pred = tree.predict(X_test)

print(y_pred[:10])
print(y_test[:10])

# [2 2 2 1 0 1 1 0 0 1]
# [1 2 2 1 0 2 1 0 0 1]

0:Setosa 1:Versicolor 2:Verginica

6. 性能評価

今回の分類予測における、正解率を求めます。


print('{:.3f}'.format(tree.score(X_test, y_test)))

# 0.956


■ 最後に

今回は上記1~6の手順をもとに、決定木のモデル作成・評価を行いました。
初学者の方にとって、少しでもお役に立てたようでしたら幸いです。

■ 参考文献

Pythonによるあたらしいデータ分析の教科書

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0