LoginSignup
2
4

More than 1 year has passed since last update.

決定木学習の決定境界選定をイメージでとらえる

Last updated at Posted at 2022-01-01

#1. はじめに

決定木学習では、決定木の根から初めて、情報利得が最大となるように分割を行っていきます。
今回はその決定境界がどのようなイメージで選定されるのかを可視化してみました。

#2. 決定木学習の題材

データセット:IRISデータ
X[0]:花びらの長さ
X[1]:花びらの幅
クラス数:3

original.png

#データ準備
from sklearn import datasets
import numpy as np

iris = datasets.load_iris()

X = iris.data[:,[2,3]]
y = iris.target

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
	X,y,test_size = 0.3, random_state=1,stratify=y)

#計算部
from sklearn.tree import DecisionTreeClassifier

tree_model = DecisionTreeClassifier(criterion='gini', max_depth=4, random_state=1)
tree_model.fit(X_train,y_train)

#描画部
import matplotlib.pyplot as plt
from sklearn import tree

tree.plot_tree(tree_model)
plt.show()

決定木は以下の通りになりました。

tree.png

今回は1つ目の決定境界、X[1]<=0.75の選定について、イメージ化を行いました。

#3. 決定境界選定のイメージ

下図左はX[1]について、決定境界を動かした際の散布図、
右はそれに伴う分離後のジニ不純度になります。

分離後の不純度は低いほど、情報利得が大きくなるため、良いとします。

image-min.gif

X[1] = 0.75で分離した際の図は以下の通りになります。
クラスラベル0は完全に分離されており、不純度が低いことが分かります。
また、ジニ不純度はこの際に最低の値をとります。

0075.png

X[1]=1.35で分離した際の図は以下の通りになります。
クラスラベル1が両方の領域に分かれており、不純度が高いことがわかります。
また、ジニ不純度はこの際に最高の値をとります。

0136.png

X[1]=1.8で分離した際の図は以下の通りになります。
クラスラベル2がおおよそ分離されていることが分かります。
また、ジニ不純度は再度低くなりますが、X[1]=0.75の際より大きいです。

0182.png

#4. (備考)GIF画像作成
##4-1. 画像作成

GIFのもとになる画像作成コードです。

#ジニ関数定義
def gini(p):
	sum = 1
	for i in p:
		sum -= (i*i)
	return sum

#ジニ不純度計算
def calc(cri,X_train,y_train):

	sep = np.where(X_train[:,1]<cri,0,1)

	y_train0 = y_train[sep == 0]
	y_train1 = y_train[sep == 1]

	yb_train0 = np.bincount(y_train0)
	yb_train1 = np.bincount(y_train1)

	ybn_train0 = yb_train0/sum(yb_train0)
	ybn_train1 = yb_train1/sum(yb_train1)

	gini0 = gini(ybn_train0)
	gini1 = gini(ybn_train1)

	if ybn_train0.size ==0: gini0 = 0
	if ybn_train1.size ==0: gini1 = 0

	gini_s = gini0 + gini1

	return gini_s

def make_gini_graph(cri,path):

	plt.rcParams['figure.subplot.bottom'] = 0.15
	fig, ax = plt.subplots(nrows=1, ncols=2,figsize=(6,3))
	fig.subplots_adjust(wspace = 0.4)

	ax[0].scatter(X_train[y_train == 0,0],X_train[y_train ==0,1],c='b',marker='x')
	ax[0].scatter(X_train[y_train == 1,0],X_train[y_train ==1,1],c='r',marker='o')
	ax[0].scatter(X_train[y_train == 2,0],X_train[y_train ==2,1],c='y',marker='s')

	X_ax,Y_ax = np.meshgrid(np.linspace(0, 7, 700), np.linspace(0, 3, 300))
	Z = np.where(Y_ax <= cri,0,1)

	ax[0].contourf(X_ax, Y_ax,Z,alpha=0.2,cmap = 'bwr')

	ax[0].set_xlabel('X[0]')
	ax[0].set_ylabel('X[1]')

	X1 = range(300)
	X_ax1 = [i * 0.01 for i in X1]
	X2 = np.ones(2)
	X_ax2 = [i*cri for i in X2]
	y_ax = np.linspace(0.45,0.95,2)

	ax[1].plot(X_ax1,results,c='b')
	ax[1].plot(X_ax2, y_ax ,c='r')

	ax[1].set_xlabel('X[1]')
	ax[1].set_ylabel('gini')

	ax[1].set_xticklabels([" ", 0.0,1.0,2.0,3.0])

	fig.savefig(path)
	plt.close()

import os
import matplotlib.pyplot as plt

results = []
for i in range(300):
	cri = i * 0.01
	result = calc(cri,X_train,y_train)
	results.append(result)

for i in range (300):
	file = "image\\{:04d}.png".format(i)
	path = os.path.join(os.path.dirname(__file__), file)
	make_gini_graph(i*0.01,path)

##4-2. GIF作成

GIF画像の作成コードです。

from PIL import Image
import glob

files = sorted(glob.glob('./image/*.png'))  
images = list(map(lambda file : Image.open(file) , files))
print(images)
images[0].save('./image/image.gif' , save_all = True , append_images = images[1:] , duration = 50 , loop = 0)

#5. まとめ

決定木について、イメージがつかず試してみたところ、
少し理解が深まったと思います。

もう少し勉強してみようと思います。

2
4
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
4