1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

CHAIDで決定木。Pythonで

Posted at

はじめに

Pythonを使って、CHAIDの決定木を実行する話。
はまったことも紹介します。

前回の記事のCHAIDやりたい!っていうところから1歩調べてみました。
[前回] sklearnのDecisionTreeClassifierの結果の理解
https://qiita.com/yo16/items/36ec237a574d8ab86a75

使うライブラリはこちら

pip install CHAID

使うデータ=sklearn.datasets.load_iris

sklearnのirisデータを使ってみます。

from sklearn.datasets import load_iris

iris = load_iris()
X = iris.data
y = iris.target

ここで注目。

>>> print(X)
[[5.1 3.5 1.4 0.2]
 [4.9 3.  1.4 0.2]
 [4.7 3.2 1.3 0.2]
 [4.6 3.1 1.5 0.2]
 [5.  3.6 1.4 0.2]
 (省略)
 [6.2 3.4 5.4 2.3]
 [5.9 3.  5.1 1.8]]

1行ずつ、配列になってます。
使用する Rambatino/CHAID は、1列ごとに配列になってないといけないです。ここに結構はまりました。

加工

1列ごとの配列に加工します。

X_t = list(zip(*X))

list(zip(*X))の説明
zip関数に、*Xを渡します。*Xというのは、外側の [ ]を外して渡すというもの。つまり[5.1, 3.5 1.4 0.2] [4.9 3.0 1.4 0.2] ~ [5.9 3. 5.1 1.8]という、行数分の配列(irisの場合は150件)を一気にzipに渡します。そうすると、zipは、1/150から1つ目を、2/150から1つ目を、・・・150/150から1つ目を取り出し、くっつけて150要素の配列にします。つぎに1/150から2つ目を、2/150から2つ目を・・・150/150から2つ目を取り出し、くっつけて150要素の配列にします。・・・を4回繰り返します。
それをlist()すると、150要素ある配列が4つ、配列になります。

>>> print(X_t)
[(5.1, 4.9, 4.7, 4.6, 5.0, (~省略~), 6.2, 5.9),
 (3.5, 3. , 3.2, 3.1, 3.6, (~省略~), 3.4, 3. ),
 (                (~省略~)                   ),
 (0.2, (~省略~)                   , 2.3, 1.8)]

tは、転置行列という意味で名付けてます。

CHAIDで決定木

ここまでできればあとは簡単。

from CHAID import Tree, NominalColumn, OrdinalColumn

# OrdinalColumnが4要素入っている配列を作る
cols = [OrdinalColumn(x, name=f'col_{i}') for i, x in enumerate(X_t)]

# 決定木を実施
tree = Tree(cols, NominalColumn(y), {'min_child_node_size': 5)

Tree()に、説明変数、目的変数の順に入れます。
変数のタイプは、2つあります。
OrdinalColumn()は、順序尺度。数字の大小、順序に意味がある数値。irisのサイズは順序に意味があるのでこちら。
NominalColumn()は、名義尺度。数字に意味はない。irisの0,1,2の順序は関係ないのでこちら。

変数のタイプは、順序尺度であって連続値ではない点には注意です。(column.pyにはContinuousColumn()っていうのがあるんだけど、目的変数にしか使ってないのかも)
数値が、1,2,100の3タイプがあったら、1と2が近くて100は遠いということは意識せず、1,2,3と同じように判断するということになります。危うい。。

確認

print_tree()ってやると、わかりづらいけどツリーが出ます。とりあえず中かっこに着目すると、0,1,2の数がもともと{0: 50.0, 1: 50.0, 2: 50.0}だったものが、col_2、col_0の値によって、分けられてます。
図にしたりする関数もあるけど、そこは別の話なのでこの記事では割愛。

# 確認
tree.print_tree()
#([], {0: 50.0, 1: 50.0, 2: 50.0}, (col_2, p=2.561410449304364e-54, score=256.52173913043475, groups=[[1.9], [3.9, 4.3], [5.2, 6.7]]), dof=4))
#|-- ([1.9], {0: 50.0, 1: 0, 2: 0}, <Invalid Chaid Split> - the node only contains single category respondents)
#|-- ([3.9, 4.3], {0: 0, 1: 48.0, 2: 6.0}, (col_0, p=0.0848944843805779, score=4.9326923076923075, groups=[[4.5], [5.2], [6.8, 7.7]]), dof=2))
#|   |-- ([4.5], {0: 0, 1: 1.0, 2: 1.0}, <Invalid Chaid Split> - the minimum parent node size threshold has been reached)
#|   |-- ([5.2], {0: 0, 1: 25.0, 2: 1.0}, <Invalid Chaid Split> - the minimum parent node size threshold has been reached)
#|   +-- ([6.8, 7.7], {0: 0, 1: 22.0, 2: 4.0}, <Invalid Chaid Split> - the minimum parent node size threshold has been reached)
#+-- ([5.2, 6.7], {0: 0, 1: 2.0, 2: 44.0}, <Invalid Chaid Split> - p-value greater than alpha merge)

まとめ

PythonでCHAIDの決定木(Rambatino/CHAID)を実行する記事は、一貫して説明してるものがあまりなく、少し苦労しました。

注意ポイントは2つありました。

  1. 説明変数の行列の方向・次元の順序(CSVを読むにしても、irisを使うにしても、今回の転置変換は必要です)
  2. 説明変数の連続値は使えない
1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?