#1. はじめに
決定木学習では、決定木の根から初めて、情報利得が最大となるように分割を行っていきます。
今回はその決定境界がどのようなイメージで選定されるのかを可視化してみました。
#2. 決定木学習の題材
データセット:IRISデータ
X[0]:花びらの長さ
X[1]:花びらの幅
クラス数:3
#データ準備
from sklearn import datasets
import numpy as np
iris = datasets.load_iris()
X = iris.data[:,[2,3]]
y = iris.target
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
X,y,test_size = 0.3, random_state=1,stratify=y)
#計算部
from sklearn.tree import DecisionTreeClassifier
tree_model = DecisionTreeClassifier(criterion='gini', max_depth=4, random_state=1)
tree_model.fit(X_train,y_train)
#描画部
import matplotlib.pyplot as plt
from sklearn import tree
tree.plot_tree(tree_model)
plt.show()
決定木は以下の通りになりました。
今回は1つ目の決定境界、X[1]<=0.75の選定について、イメージ化を行いました。
#3. 決定境界選定のイメージ
下図左はX[1]について、決定境界を動かした際の散布図、
右はそれに伴う分離後のジニ不純度になります。
分離後の不純度は低いほど、情報利得が大きくなるため、良いとします。
X[1] = 0.75で分離した際の図は以下の通りになります。
クラスラベル0は完全に分離されており、不純度が低いことが分かります。
また、ジニ不純度はこの際に最低の値をとります。
X[1]=1.35で分離した際の図は以下の通りになります。
クラスラベル1が両方の領域に分かれており、不純度が高いことがわかります。
また、ジニ不純度はこの際に最高の値をとります。
X[1]=1.8で分離した際の図は以下の通りになります。
クラスラベル2がおおよそ分離されていることが分かります。
また、ジニ不純度は再度低くなりますが、X[1]=0.75の際より大きいです。
#4. (備考)GIF画像作成
##4-1. 画像作成
GIFのもとになる画像作成コードです。
#ジニ関数定義
def gini(p):
sum = 1
for i in p:
sum -= (i*i)
return sum
#ジニ不純度計算
def calc(cri,X_train,y_train):
sep = np.where(X_train[:,1]<cri,0,1)
y_train0 = y_train[sep == 0]
y_train1 = y_train[sep == 1]
yb_train0 = np.bincount(y_train0)
yb_train1 = np.bincount(y_train1)
ybn_train0 = yb_train0/sum(yb_train0)
ybn_train1 = yb_train1/sum(yb_train1)
gini0 = gini(ybn_train0)
gini1 = gini(ybn_train1)
if ybn_train0.size ==0: gini0 = 0
if ybn_train1.size ==0: gini1 = 0
gini_s = gini0 + gini1
return gini_s
def make_gini_graph(cri,path):
plt.rcParams['figure.subplot.bottom'] = 0.15
fig, ax = plt.subplots(nrows=1, ncols=2,figsize=(6,3))
fig.subplots_adjust(wspace = 0.4)
ax[0].scatter(X_train[y_train == 0,0],X_train[y_train ==0,1],c='b',marker='x')
ax[0].scatter(X_train[y_train == 1,0],X_train[y_train ==1,1],c='r',marker='o')
ax[0].scatter(X_train[y_train == 2,0],X_train[y_train ==2,1],c='y',marker='s')
X_ax,Y_ax = np.meshgrid(np.linspace(0, 7, 700), np.linspace(0, 3, 300))
Z = np.where(Y_ax <= cri,0,1)
ax[0].contourf(X_ax, Y_ax,Z,alpha=0.2,cmap = 'bwr')
ax[0].set_xlabel('X[0]')
ax[0].set_ylabel('X[1]')
X1 = range(300)
X_ax1 = [i * 0.01 for i in X1]
X2 = np.ones(2)
X_ax2 = [i*cri for i in X2]
y_ax = np.linspace(0.45,0.95,2)
ax[1].plot(X_ax1,results,c='b')
ax[1].plot(X_ax2, y_ax ,c='r')
ax[1].set_xlabel('X[1]')
ax[1].set_ylabel('gini')
ax[1].set_xticklabels([" ", 0.0,1.0,2.0,3.0])
fig.savefig(path)
plt.close()
import os
import matplotlib.pyplot as plt
results = []
for i in range(300):
cri = i * 0.01
result = calc(cri,X_train,y_train)
results.append(result)
for i in range (300):
file = "image\\{:04d}.png".format(i)
path = os.path.join(os.path.dirname(__file__), file)
make_gini_graph(i*0.01,path)
##4-2. GIF作成
GIF画像の作成コードです。
from PIL import Image
import glob
files = sorted(glob.glob('./image/*.png'))
images = list(map(lambda file : Image.open(file) , files))
print(images)
images[0].save('./image/image.gif' , save_all = True , append_images = images[1:] , duration = 50 , loop = 0)
#5. まとめ
決定木について、イメージがつかず試してみたところ、
少し理解が深まったと思います。
もう少し勉強してみようと思います。