
Decision trees with scikit-learn

Posted at 2018-04-26

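Let's check how a decision tree actually classifies data. Below, two overlapping classes of 100 points each (centred at (2, 2) and (3, 3), standard deviation 1) are fitted with DecisionTreeClassifier for max_depth = 1 to 10, and the resulting decision regions are plotted.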

import matplotlib.pyplot as plt
import numpy as np
from sklearn.tree import DecisionTreeClassifier

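np.random.seed(20180426)  # fix the random seed so the results are reproducible

# Class 0: 100 points around (2, 2), standard deviation 1 in each coordinate
X = np.array([[i, j] for i, j in zip(np.random.normal(2, 1, 100), np.random.normal(2, 1, 100))])
y = np.array([0] * 100)

# Class 1: 100 points around (3, 3), appended after class 0
X = np.append(X, np.array([[i, j] for i, j in zip(np.random.normal(3, 1, 100), np.random.normal(3, 1, 100))]), axis=0)
y = np.append(y, [1] * 100)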


def draw_tree(n):
    clf = DecisionTreeClassifier(max_depth=n)

    clf.fit(X, y)

    # Grid covering the data with a margin of 1 on each side, step 0.1
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
                         np.arange(y_min, y_max, 0.1))
    color_set = ['b' if i == 0 else 'r' for i in y]  # blue = class 0, red = class 1

    plt.figure(figsize=(8, 8))

    # Predict a class for every grid point, then reshape back to the grid
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)

    plt.contourf(xx, yy, Z, alpha=0.4)  # shaded decision regions
    plt.scatter(X[:, 0], X[:, 1], c=color_set, s=30, edgecolor='k')
    plt.title("Tree depth: " + str(n))

    # plt.savefig("graph" + str(n) + ".png")
    plt.show()


for i in range(1, 11):
    draw_tree(i)

[Figure: decision regions, tree depth 1 (graph1.png)]
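Still rough: a depth-1 tree makes a single split on one feature, so the plane is just cut in two by a straight line.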
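At depth 1 the tree picks the one split that lowers Gini impurity (scikit-learn's default `criterion="gini"`) the most. Below is a brute-force sketch of that search, written from scratch for illustration (`gini` and `best_stump` are hypothetical helpers, not scikit-learn functions), compared with the root split scikit-learn actually learns on the same data:

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier


def gini(labels):
    """Gini impurity 1 - sum_k p_k**2 of an array of class labels."""
    if len(labels) == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)


def best_stump(X, y):
    """Try every feature and every midpoint between neighbouring distinct values.
    Returns (weighted child impurity, feature index, threshold) of the best split."""
    n = len(y)
    best = (np.inf, None, None)
    for f in range(X.shape[1]):
        values = np.unique(X[:, f])  # sorted distinct values of this feature
        for t in (values[:-1] + values[1:]) / 2:
            left = X[:, f] <= t
            n_left = left.sum()
            w = (n_left * gini(y[left]) + (n - n_left) * gini(y[~left])) / n
            if w < best[0]:
                best = (w, f, t)
    return best


np.random.seed(20180426)  # same data as in the article
X = np.array([[i, j] for i, j in zip(np.random.normal(2, 1, 100), np.random.normal(2, 1, 100))])
y = np.array([0] * 100)
X = np.append(X, np.array([[i, j] for i, j in zip(np.random.normal(3, 1, 100), np.random.normal(3, 1, 100))]), axis=0)
y = np.append(y, [1] * 100)

w, feature, threshold = best_stump(X, y)
print(f"brute force : x{feature} <= {threshold:.3f}, weighted Gini = {w:.4f}")

# scikit-learn's depth-1 tree: node 0 is the root, nodes 1 and 2 are its two leaves
tree = DecisionTreeClassifier(max_depth=1).fit(X, y).tree_
sk_w = (tree.n_node_samples[1] * tree.impurity[1]
        + tree.n_node_samples[2] * tree.impurity[2]) / tree.n_node_samples[0]
print(f"scikit-learn: x{tree.feature[0]} <= {tree.threshold[0]:.3f}, weighted Gini = {sk_w:.4f}")
```

The two weighted impurities should agree; if two candidate splits happen to tie, the reported feature or threshold could differ.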

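[Figure: decision regions, tree depth 2 (graph2.png)]
[Figure: decision regions, tree depth 3 (graph3.png)]
[Figure: decision regions, tree depth 4 (graph4.png)]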

Looks like it's getting quite a bit better?
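To see what the regions at these depths are made of, the fitted tree can also be printed as plain-text rules with `sklearn.tree.export_text`. A minimal sketch, assuming scikit-learn 0.21 or later (`export_text` did not exist yet when this article was written); depth 2 and the feature names `x0`/`x1` are illustrative choices:

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

np.random.seed(20180426)  # same seed as in the article

# Same data as above: class 0 around (2, 2), class 1 around (3, 3)
X = np.array([[i, j] for i, j in zip(np.random.normal(2, 1, 100), np.random.normal(2, 1, 100))])
y = np.array([0] * 100)
X = np.append(X, np.array([[i, j] for i, j in zip(np.random.normal(3, 1, 100), np.random.normal(3, 1, 100))]), axis=0)
y = np.append(y, [1] * 100)

clf = DecisionTreeClassifier(max_depth=2)  # depth 2 keeps the printout short
clf.fit(X, y)

# One line per threshold test on a single feature; leaf lines show the predicted class
rules = export_text(clf, feature_names=["x0", "x1"])
print(rules)
```

Every rule is a threshold on a single feature, which is why all the boundaries in the plots are horizontal or vertical segments.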

[Figure: decision regions, tree depth 5 (graph5.png)]
[Figure: decision regions, tree depth 6 (graph6.png)]

By this point it definitely feels like it's overfitting. At depth 10 it ended up like this:

[Figure: decision regions, tree depth 10 (graph10.png)]
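One way to back up the "feels like overfitting" impression with numbers is to hold out part of the data and compare training and test accuracy as the depth grows. A minimal sketch, assuming a 70/30 stratified split and `random_state=0` (neither is part of the original article):

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

np.random.seed(20180426)  # same data as in the article
X = np.array([[i, j] for i, j in zip(np.random.normal(2, 1, 100), np.random.normal(2, 1, 100))])
y = np.array([0] * 100)
X = np.append(X, np.array([[i, j] for i, j in zip(np.random.normal(3, 1, 100), np.random.normal(3, 1, 100))]), axis=0)
y = np.append(y, [1] * 100)

# Hold out 30% of the points; stratify keeps the 50/50 class balance in both parts
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=0)

train_acc, test_acc = [], []
for depth in range(1, 11):
    clf = DecisionTreeClassifier(max_depth=depth, random_state=0).fit(X_train, y_train)
    train_acc.append(clf.score(X_train, y_train))  # accuracy on the data it was fitted on
    test_acc.append(clf.score(X_test, y_test))     # accuracy on unseen points
    print(f"depth {depth:2d}: leaves={clf.get_n_leaves():3d}  "
          f"train={train_acc[-1]:.3f}  test={test_acc[-1]:.3f}")
```

If the deeper trees are overfitting, training accuracy keeps climbing while test accuracy stops improving or falls; the exact numbers depend on the split.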

Let's see what score 5-fold cross-validation gives for each depth (the mean accuracy over the folds).
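For reference, `cross_val_score(clf, X, y, cv=5)` with a classifier splits the data with `StratifiedKFold(n_splits=5)` (no shuffling), fits a fresh clone of the model on four folds, and scores accuracy on the remaining one. A minimal sketch of that loop written out by hand, assuming depth 4 and `random_state=0` (my choices, so that the manual loop and `cross_val_score` train identical trees):

```python
import numpy as np
from sklearn.base import clone
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.tree import DecisionTreeClassifier

np.random.seed(20180426)  # same data as in the article
X = np.array([[i, j] for i, j in zip(np.random.normal(2, 1, 100), np.random.normal(2, 1, 100))])
y = np.array([0] * 100)
X = np.append(X, np.array([[i, j] for i, j in zip(np.random.normal(3, 1, 100), np.random.normal(3, 1, 100))]), axis=0)
y = np.append(y, [1] * 100)

model = DecisionTreeClassifier(max_depth=4, random_state=0)

manual = []
for train_idx, test_idx in StratifiedKFold(n_splits=5).split(X, y):
    fold_model = clone(model).fit(X[train_idx], y[train_idx])  # fresh, unfitted copy per fold
    manual.append(fold_model.score(X[test_idx], y[test_idx]))  # accuracy on the held-out fold

auto = cross_val_score(model, X, y, cv=5)
print(np.round(manual, 3), np.round(auto, 3))
```

Because each fold gets its own clone, fitting the model on the full data beforehand has no effect on the cross-validation scores.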

import matplotlib.pyplot as plt
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import cross_val_score

np.random.seed(20180426)  # fix the random seed so the results are reproducible

X = np.array([[i, j] for i, j in zip(np.random.normal(2, 1, 100), np.random.normal(2, 1, 100))])
y = np.array([0] * 100)

X = np.append(X, np.array([[i, j] for i, j in zip(np.random.normal(3, 1, 100), np.random.normal(3, 1, 100))]), axis=0)
y = np.append(y, [1] * 100)


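def calculate_score(n):
    clf = DecisionTreeClassifier(max_depth=n)

    # cross_val_score fits its own copy of clf on each fold, so no separate fit is needed
    scores = cross_val_score(clf, X, y, cv=5)
    return np.average(scores)  # mean accuracy over the 5 folds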


depths = list(range(1, 11))
scores = [calculate_score(d) for d in depths]

plt.xlabel("Tree depth")
plt.ylabel("Score (mean CV accuracy)")
plt.xticks(depths)
plt.plot(depths, scores)
# plt.savefig("result.png")
plt.show()

[Figure: mean 5-fold cross-validation score vs. tree depth (result.png)]

As expected, a depth of around 4 came out best.
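Instead of looping over depths by hand, `GridSearchCV` can run the same 5-fold comparison and report the best `max_depth` directly. A sketch assuming `random_state=0` for the tree; the selected depth depends on the data and seed, so it is not guaranteed to come out as exactly 4:

```python
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

np.random.seed(20180426)  # same data as in the article
X = np.array([[i, j] for i, j in zip(np.random.normal(2, 1, 100), np.random.normal(2, 1, 100))])
y = np.array([0] * 100)
X = np.append(X, np.array([[i, j] for i, j in zip(np.random.normal(3, 1, 100), np.random.normal(3, 1, 100))]), axis=0)
y = np.append(y, [1] * 100)

# Try max_depth = 1..10, scoring each with 5-fold cross-validated accuracy
search = GridSearchCV(
    DecisionTreeClassifier(random_state=0),
    param_grid={"max_depth": list(range(1, 11))},
    cv=5,
)
search.fit(X, y)
print(search.best_params_, round(search.best_score_, 3))
```

`best_score_` is the same cross-validated score that was used to choose the depth, so it is a slightly optimistic estimate of performance on new data.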
