※この記事は記載途中の箇所があります。またあまりまとまっておらずかなり読みにくいです。
決定木学習とは
分類や回帰問題で使用できる機械学習アルゴリズムの一種である。簡単のために分類問題のみを考える。
ロジスティック回帰やSVMなどの他の機械学習アルゴリズムと異なる点として、入力データの特徴量のうちどの特徴量を使用してどんな基準で分類を行っているのかが分かりやすい点がある。
決定木とは、以下のように入力データを受け取り、ある基準で入力データを分割し出力するノードを繰り返し生成することでできる木である。決定木学習では各ノードでどの特徴量をどんな基準で用いて分割したのかが分かる。
決定木が考案された背景
従来の機械学習アルゴリズムではどの特徴量を使いどんな基準で分類したのかが解釈しづらかったためと考えられる。解釈しづらいと、精度の向上・悪化要因が分からず改善が難しい問題が発生する。
決定木の原理
前述の通り入力データを受け取りいずれかの特徴量を用いて分割したいがどうすれば良いだろうか?具体例を使用して考えてみる。フィッシャーのアイリスデータを例とする。簡単のために2値分類とし、 セトナ(setosa)かそうでないかを予測することを目標とする。今やりたいことは、がく片長(Sepal Length)、がく片幅(Sepal Width)、花びら長(Petal Length)、花びら幅(Petal Width)の4つの特徴量を持ったいくつかのサンプルデータを受け取り、4つの特徴量のいずれか1つを用いてある基準によりデータを分割することである。
簡単なケース1:2値分類かつ特徴量1つ
これでもまだ難しいためもっと簡単なケースを考える。予測したいことは変えずに、入力データの特徴量ががく片のみであるとする。また、実際にがく片長のみでセトナかどうか分類できると仮定する。例として今回はがく片長が5cm以上であればセトナであり、5cm未満であればセトナでないとする。すなわち、決定木にはがく片長が5cm以上かどうかを基準として分割させたいとする。単純な方法の一つは、様々な基準で分割してみて最も精度良く分類できた基準を採用する方法だろう。実際にPythonでこの単純な方法を実装してみる。
まず入力データを用意する。サンプル数を100とする。予測すべき値は1のときセトナであり、0のときセトナでないとする。
import numpy as np
import matplotlib.pyplot as plt
import random
random.seed(0)
n_X = 100
X = np.array([])
y = np.array([])
for i in range(n_X):
if i < 50:
X = np.append(X, random.uniform(0.1, 4.999))
y = np.append(y, 0)
else:
X = np.append(X, random.uniform(5.001, 10))
y = np.append(y, 1)
plt.scatter(X[y==0], np.zeros((1, n_X//2)), c='red', label="not setosa")
plt.scatter(X[y==1], np.zeros((1, n_X//2)), c='blue', label="setosa")
plt.xlabel('Sepal length [cm]')
plt.legend()
※グラフの縦軸に意味はありません。1次元の特徴量の分布を描きたかったのですがやり方がわからず2次元プロットしたため表示上縦軸が在るように見えているだけです。
様々な基準で分割して各分割基準ごとに精度を計算します。精度は単純に正しく分類された割合とします。ある基準を与えたときに、その基準以上をセトナとするかその基準未満をセトナとするかは、両方を計算し精度が高かった方とします。
def calc_accuracy(threshold):
y_pred = np.zeros((y.shape))
y_pred[X < threshold] = 1 # 基準値未満をセトナとする
accuracy_less_than = len(y_pred[y == y_pred])
y_pred = np.zeros((y.shape))
y_pred[threshold <= X] = 1 # 基準値以上をセトナとする
accuracy_more_than_or_eq = len(y_pred[y == y_pred])
if accuracy_less_than < accuracy_more_than_or_eq:
accuracy = accuracy_more_than_or_eq
is_less_than = False
else:
accuracy = accuracy_less_than
is_less_than = True
return accuracy, is_less_than
thresholds = np.arange(0.1, 10.1, 0.1)
accuracies = []
are_less_than = []
for threshold in thresholds:
accuracy, is_less_than = calc_accuracy(threshold)
accuracies.append(accuracy)
are_less_than.append(is_less_than)
threshold_pred = thresholds[np.argmax(accuracies)]
is_less_than_pred = are_less_than[np.argmax(accuracies)]
accuracy = np.max(accuracies)
if is_less_than_pred:
print(f"setosa has sepal lenght which is less than {threshold_pred} cm with {accuracy}% accuracy")
else:
print(f"setosa has sepal lenght which is more than or equal to {threshold_pred} cm with {accuracy}% accuracy")
実行結果は以下の通りであった。
setosa has sepal lenght which is more than or equal to 5.0 cm with 100% accuracy
うまく計算できているようである。
簡単なケース2:マルチ分類かつ特徴量1
次にマルチ分類かつ特徴量が1つの場合を考える。入力データのサンプル数は同様に100とする。特徴量も同様にがく片長とするが、簡単のため以がく片長に応じて下のように値を割り当てる。がく片長 $x$ [cm]が
- $0 < x \le 3$ のとき $0$
- $3 < x \le 6$ のとき $1$
- $6 < x \le 9$ のとき $2$
とする。また、予測値はセトナ、バーシクル、バージニカのいずれかとする。それぞれ以下のようにラベルを割り当てる。
- セトナ:0
- バーシクル:1
- バージニカ:2
真値(予測すべき答え)は、がく片長 $x$ [cm]が
- $0 < x \le 3$ のときセトナ
- $3 < x \le 6$ のときバーシクル
- $6 < x \le 9$ のときバージニカ
とする。
単純な方法は特徴量 $0, 1, 2$ ごとに分割する方法である。あまりにも単純な方法であるが、今回の場合これでうまく分類することができる。(そうなるようにデータを仮定したから当たり前であるが。)ここで以下のような疑問が生じるだろう。
がく片長に応じた値の割り当て方はどうやって決めるの?
今回はたまたま真値と同様にがく片長に応じて3つの値を割り当てたからうまくいったものの、例えばがく片長に応じた値の割り当て方を
- $0 < x \le 0.1$ のとき $0$
- $0.1 < x \le 0.2$ のとき $1$
- $0.2 < x \le 9$ のとき $2$
とした場合、単純に特徴量 $0, 1, 2$ ごとに分割すると全くうまく分類できない。この問題への対応方法は後で考えることとしてまずはこの単純な手法をマルチ分類かつ特徴量2つの場合に拡張してみる。
マルチ分類かつ特徴量2つ
特徴量ががく片長、がく片幅の2つとする。先ほどと同様にがく片長に応じて下のように値を割り当てる。がく片長 $x_{1}$ [cm]が
- $0 < x_{1} \le 3$ のとき $0$
- $3 < x_{1} \le 6$ のとき $1$
- $6 < x_{1} \le 9$ のとき $2$
とする。
がく片幅についても同様にして以下のように値を割り当てる。がく片幅 $x_{2}$ [cm]が
- $0 < x_{2} \le 3$ のとき $0$
- $3 < x_{2} \le 6$ のとき $1$
- $6 < x_{2} \le 9$ のとき $2$
とする。
決定木学習では、2つの特徴量のうちどちらか1つを選んで分割し、その後選ばなかったもう1つの特徴量を使用して分割した各組を更に分割する。ここで2つの特徴量を同時に使用して分割しても良さそうだが、決定木はオッカムの剃刀という「ある事柄を説明するためには、必要以上に多くを仮定するべきでない」という考えに基づいているためこのようにしている。まあ確かにできるだけ少ない特徴量で決定木が作れたほうが過学習が防げるし、解釈もしやすいと考えられる。(後は計算が簡単になるというのもあるのかな。)
さて、ここで問題となるのは始めの分割のときに2つの特徴量からどちらを選ぶかである。言い換えると良い分割基準をどうやって決めるのかということである。次にこれを考える。
分割基準の選び方
分類基準の良し悪しはどう評価すれば良いだろうか?単純なアイディアの1つは、正しく分類されたデータの割合で評価することである。しかし、計算コストが高い問題がある。具体的には、基準が連続値の特徴量以上であった場合に、基準量以上がどのクラスかを全て試す必要がある。例えば、最初の例のようにがく片長が基準値以上をセトナとしたときと、基準値未満をセトナとしたときで精度を比較し、精度が良かった方を採用する必要がある。これはクラス数が3に増加した場合は以下のように複数のパターンで精度を比較し、最も精度が良かったパターンを採用することとなる。
これらのパターンはクラス数・特徴量の数が増加するに従いどんどん増加するため、計算コストが大きくなってしまう。この問題を回避する方法を考える。最初の例のようにがく片長を基準にセトナかどうかを分類する場合を考える。勘の良い方は気づかれているかもしれないが、精度を求めずとも基準値で分割したときにセトナかどうかが分割できてさえいれば良いのである。例えば、5cm以上がセトナで、5cm未満がセトナであるならば、5cmという基準値さえ求められれば教師データがあるため、どちらの領域がセトナなのか判断できるのである。すなわち、入力データを基準値で分割したときに、分割された各領域に属するクラスが同一であれば良いのである。
というわけで評価関数としては分割された領域ごとの同じクラスに属するデータの割合を使用すれば良さそうである。つまり、最も良く分類できている場合は各領域で同じクラスに属するデータの割合は100%となり、そこから真の基準からずれるごとに各領域の同じクラスに属するデータの割合は減少していくことがわかる。因みにこのどれだけ各領域で同じクラスにデータが属しているかという指標を純度と呼ぶ。最適化アルゴリズムでは一般に最小化問題を解くことを考えるため、実際のアルゴリズムでは評価関数として**不純度(どれだけ各領域で同じクラスにデータが属していないか)**を使用する。因みに不純度としては、1.10. Decision Trees — scikit-learn 0.23.1 documentationによるとGini、エントロピー、Misclassificationという指標が使用されるようである。Giniはジニ係数という経済学の分野で用いられる指標から発想を得て作成された。ジニ係数は主に所得の不平等さを表す指標であるそうだ(勉強不足のため詳しいことはわかりません)。エントロピーは情報の乱雑さを表す指標であるため、各領域にどれだけ同じクラスでないデータがあるかを評価する指標に使えそうだと思えるだろう。Misclassificationはおそらく一番分かりやすい指標であり、各領域での同じクラスに属するデータの割合の最大値を純度の代表値として、各領域で同じクラスに属さないデータの割合を近似したものと考えられる(「クラス数 - 各領域で同じクラスに属さないデータの割合の合計値」としても良さそうだがそうしない理由があるのだろうか?)。
決定木作成アルゴリズム
ここまでの議論で決定木をどう作成すれば良いかがほとんどわかった。次に実際にアルゴリズムを作成する。入力データはセトナ、バーシクル、バージニカのがく片長、がく片幅とする。予測すべき値はセトナ、バーシクル、バージニカのどれに分類できるかとする。
1つ追加で考える必要があることがある。連続値の特徴量をどう扱うかという問題である。特徴量がカテゴリカルなものであれば、不純度が最小になる特徴量を選択し、選択された特徴量に従い分割すれば良い。そこで、連続値の特徴量からカテゴリカルな特徴量を作成する。具体的には、始めの例のように、いくつか分割基準をサンプルしそれぞれをカテゴリカルな特徴量とする。それぞれについて不純度を計算し、不純度が最小となる特徴量で分割する。今回は簡単のため2分割しかしないこととする。
具体的なアルゴリズムの手順は以下の通りである。
- 各特徴量について分割基準をいくつかサンプルしカテゴリカルな特徴量を生成する
- 各カテゴリカルな特徴量について不純度を計算する
- 不純度が最小となる特徴量を選択しデータを分割する
- 分割された2つのデータ集合それぞれに対して 1. から 4. を繰り返す
上記の手順よりこのアルゴリズムは再帰的なアルゴリズムであることが分かる。
また、分割の終了条件を考える必要がある。ここでは、以下を終了条件とする。
- 分割後の集合のデータが全て同じクラスに属する場合
- 選択できる特徴量がなくなった場合
- 分割後の集合にデータが1つもない場合
まず具体的なデータセットを用意する。今回は意図的にデータを作成する。
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
# Sample from:
d = 2 # Number of dimensions
mean = np.matrix([[0.], [1.]])
covariance = np.matrix([
[1, 0.8],
[0.8, 1]
])
# Create L
L = np.linalg.cholesky(covariance)
# Sample X from standard normal
n = 50 # Samples to draw
np.random.seed(0)
_X = np.random.normal(size=(d, n))
# Apply the transformation
X_setosa = L.dot(_X) + mean
X_setosa = X_setosa.T
y_setosa = np.zeros((n, 1))
# Sample from:
d = 2 # Number of dimensions
mean = np.matrix([[3.], [4.]])
covariance = np.matrix([
[1, 0.8],
[0.8, 1]
])
# Create L
L = np.linalg.cholesky(covariance)
# Sample X from standard normal
n = 50 # Samples to draw
_X = np.random.normal(size=(d, n))
# Apply the transformation
X_versicolor = L.dot(_X) + mean
X_versicolor = X_versicolor.T
y_versicolor = np.ones((n, 1))
X = np.vstack([X_setosa, X_versicolor])
y = np.vstack([y_setosa, y_versicolor])
fig, ax = plt.subplots(figsize=(6, 4.5))
ax.scatter(X[:, 0][y==0].A1, X[:, 1][y==0].A1, c='red', label='setosa')
ax.scatter(X[:, 0][y==1].A1, X[:, 1][y==1].A1, c='blue', label='versicolor')
ax.set_xlabel('Sepal length [cm]', fontsize=13)
ax.set_ylabel('Sepal width [cm]', fontsize=13)
ax.axis([-3.1, 7.1, -3.1, 7.1])
ax.set_aspect('equal')
ax.legend()
plt.show()
次に1. , 2.を実行する。
sepal_len_thresholds = np.arange(-2, 6)
sepal_wid_thresholds = np.arange(-1, 6)
def Entropy(X_class1, X_class2):
p_class1 = 1 / (len(X_class1) + len(X_class2)) * len(X_class1)
p_class2 = 1 / (len(X_class1) + len(X_class2)) * len(X_class2)
if p_class1 == 0:
entropy = -1 * (p_class2 * np.log10(p_class2))
elif p_class2 == 0:
entropy = -1 * (p_class1 * np.log10(p_class1))
else:
entropy = -1 * (p_class1 * np.log10(p_class1) + p_class2 * np.log10(p_class2))
return entropy
fig, axs = plt.subplots(1, len(sepal_len_thresholds), figsize=(30, 30))
axs = axs.ravel()
for i in range(len(sepal_len_thresholds)):
sepal_len_threshold = sepal_len_thresholds[i]
sep_data_left_ind = X[:, 0] <= sepal_len_threshold
sep_data_left_ind = sep_data_left_ind.A1
X_left = X[sep_data_left_ind, :]
y_left = y[sep_data_left_ind]
X_right = X[np.invert(sep_data_left_ind), :]
y_right = y[np.invert(sep_data_left_ind)]
X_class1 = X_left[:, 0][y_left==0].A1
X_class2 = X_left[:, 0][y_left==1].A1
impurity_left = Entropy(X_class1, X_class2)
X_class1 = X_right[:, 0][y_right==0].A1
X_class2 = X_right[:, 0][y_right==1].A1
impurity_right = Entropy(X_class1, X_class2)
total_impurity = impurity_left + impurity_right
axs[i].scatter(X[:, 0][y==0].A1, X[:, 1][y==0].A1, s=15, c='red', label='setosa')
axs[i].scatter(X[:, 0][y==1].A1, X[:, 1][y==1].A1, s=15, c='blue', label='versicolor')
# axs[i].scatter(X_right[:, 0][y_right==0].A1, X_right[:, 1][y_right==0].A1, c='orange', label='setosa')
# axs[i].scatter(X_right[:, 0][y_right==1].A1, X_right[:, 1][y_right==1].A1, c='green', label='versicolor')
axs[i].axvline(x=sepal_len_threshold, ymin=0, ymax=1)
axs[i].set_xlabel('Sepal length [cm]', fontsize=10)
axs[i].set_ylabel('Sepal width [cm]', fontsize=10)
axs[i].axis([-3.1, 7.1, -3.1, 7.1])
axs[i].set_aspect('equal')
axs[i].set_title(f'total impurity = {np.around(total_impurity, 6)}')
axs[i].legend()
plt.show()
fig, axs = plt.subplots(1, len(sepal_wid_thresholds), figsize=(30, 30))
axs = axs.ravel()
for i in range(len(sepal_wid_thresholds)):
sepal_wid_threshold = sepal_wid_thresholds[i]
sep_data_left_ind = X[:, 1] <= sepal_wid_threshold
sep_data_left_ind = sep_data_left_ind.A1
X_left = X[sep_data_left_ind, :]
y_left = y[sep_data_left_ind]
X_right = X[np.invert(sep_data_left_ind), :]
y_right = y[np.invert(sep_data_left_ind)]
X_class1 = X_left[:, 0][y_left==0].A1
X_class2 = X_left[:, 0][y_left==1].A1
impurity_left = Entropy(X_class1, X_class2)
X_class1 = X_right[:, 0][y_right==0].A1
X_class2 = X_right[:, 0][y_right==1].A1
impurity_right = Entropy(X_class1, X_class2)
total_impurity = impurity_left + impurity_right
axs[i].scatter(X[:, 0][y==0].A1, X[:, 1][y==0].A1, s=15, c='red', label='setosa')
axs[i].scatter(X[:, 0][y==1].A1, X[:, 1][y==1].A1, s=15, c='blue', label='versicolor')
# axs[i].scatter(X_right[:, 0][y_right==0].A1, X_right[:, 1][y_right==0].A1, c='orange', label='setosa')
# axs[i].scatter(X_right[:, 0][y_right==1].A1, X_right[:, 1][y_right==1].A1, c='green', label='versicolor')
axs[i].axhline(y=sepal_wid_threshold, xmin=0, xmax=1)
axs[i].set_xlabel('Sepal length [cm]', fontsize=10)
axs[i].set_ylabel('Sepal width [cm]', fontsize=10)
axs[i].axis([-3.1, 7.1, -3.1, 7.1])
axs[i].set_aspect('equal')
axs[i].set_title(f'total impurity = {np.around(total_impurity, 6)}')
axs[i].legend()
plt.show()
次に 3. 不純度が最小となる特徴量を選択しデータを分割するを実施する
sepal_wid_threshold = sepal_wid_thresholds[3]
sep_data_left_ind = X[:, 1] <= sepal_wid_threshold
sep_data_left_ind = sep_data_left_ind.A1
X_node1_left = X[sep_data_left_ind, :]
y_node1_left = y[sep_data_left_ind]
X_node1_right = X[np.invert(sep_data_left_ind), :]
y_node1_right = y[np.invert(sep_data_left_ind)]
fig, axs = plt.subplots(1, 2, figsize=(10, 10))
axs = axs.ravel()
axs[0].scatter(X_node1_left[:, 0][y_node1_left==0].A1, X_node1_left[:, 1][y_node1_left==0].A1, c='red', label='setosa')
axs[0].scatter(X_node1_left[:, 0][y_node1_left==1].A1, X_node1_left[:, 1][y_node1_left==1].A1, c='blue', label='versicolor')
axs[0].set_xlabel('Sepal length [cm]', fontsize=13)
axs[0].set_ylabel('Sepal width [cm]', fontsize=13)
axs[0].axis([-3.1, 7.1, -3.1, 7.1])
axs[0].set_aspect('equal')
axs[0].set_title(f'Left leaf data whose sepal width less than or equal to {sepal_wid_threshold}')
axs[1].scatter(X_node1_right[:, 0][y_node1_right==0].A1, X_node1_right[:, 1][y_node1_right==0].A1, c='red', label='setosa')
axs[1].scatter(X_node1_right[:, 0][y_node1_right==1].A1, X_node1_right[:, 1][y_node1_right==1].A1, c='blue', label='versicolor')
axs[1].set_xlabel('Sepal length [cm]', fontsize=13)
axs[1].set_ylabel('Sepal width [cm]', fontsize=13)
axs[1].axis([-3.1, 7.1, -3.1, 7.1])
axs[1].set_aspect('equal')
axs[1].set_title(f'Left leaf data whose sepal width above {sepal_wid_threshold}')
ax.legend()
plt.show()