"python 決定木"とかでググると、世の中には大量の解説記事がありますが、サンプルが複雑だったりして読み解くのが面倒でした。
ということで、自分で解釈した使い方を書いておきます。
なお、決定木が何か?はある程度分かっていて「pythonで決定木を手っ取り早く作りたい」人が対象です。
import sklearn
import sklearn.tree
# 入力データの作成
features = [
# [Width, Height, Channel]
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[6, 2, 1],
]
targets = [
"Red",
"Green",
"Blue",
"Blue",
]
print(features, targets)
# 決定木の作成
tree = sklearn.tree.DecisionTreeClassifier(random_state=0, max_depth=2)
tree.fit(features, targets)
# 決定木の可視化
print(sklearn.tree.export_text(tree, feature_names=["Width", "Height", "Channel"]))
# 決定木の実行結果
print(tree.predict(features))
入力データの作成
入力データには2種類あります。1つは分類に使う特徴量(feature)、もう1つは分類結果(target)です。
それぞれ配列として持っておきます(numpy.arrayとかでも大丈夫です)。
今回は、適当に、幅・高さ・チャンネルという特徴量を作ってみて、分類結果は赤・緑・青・青になるという入力データにしました。
中身は適当です。
printすると
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [6, 2, 1]] ['Red', 'Green', 'Blue', 'Blue']
と出てくるだけですね。
決定木の作成
pythonで決定木を作るなら、sklearn.tree.DecisionTree*が一番楽っぽいです。
今回は分類なのでDecisionTreeClassifierを使います。
DecisionTreeClassifier()
に渡す引数はいっぱいあるので、そこは使い方に合わせて調べてもらうとして、重要な引数は2つ
-
random_state
:乱数を使うところのシードを固定しておきましょう。そうしないと実行のたびに結果が変わってしまうので・・・。値はなんでもいいので0とか入れておけばいいです。 -
max_depth
:木の深さです。指定しないと実用的にならないので、適当に指定しましょう。
作ったらfit()
を呼ぶだけですね。
なお、なんか世の中にはtargetに渡すものを文字列ではなくわざわざ整数の疑似変数にしているサンプルが転がっていますが、分類名をそのまま渡しても動きます(たぶん古い書き方?)。
決定木の可視化
作った決定木を可視化しましょう。
手っ取り早く確かめたい場合はexport_text()
でテキストで出力させるのが楽です。
ここでfeature_names
引数に特徴量の名前(今回は幅・高・チャンネル)を渡してやると、結果が見やすくなります(書かないと、feature_0, feature_1
とか味気ない名前が適当に付けられます)。
結果をprintしてやると
|--- Width <= 5.00
| |--- Channel <= 4.50
| | |--- class: Red
| |--- Channel > 4.50
| | |--- class: Green
|--- Width > 5.00
| |--- class: Blue
と出てきます。これは決定木そのもので解説することもないかなと思います。
なお、テキストだとこれ以上の結果は出てこないので、ジニ係数とかサンプル数とかもっと細かく結果を可視化したい場合は、export_graphviz()
を使うのが便利です。細かくは解説しませんが
sklearn.tree.export_graphviz(tree, out_file="tree.dot", feature_names=["Width", "Height", "Channel"])
とするとtree.dotというファイルが保存されるので、
$ dot -Tpng tree.dot -o tree.png
決定木の実行結果
最後に決定木に特徴量を入力した結果を見てみましょう。単にpredict()
に特徴量を突っ込むだけです。
printすると
['Red' 'Green' 'Blue' 'Blue']
とか出てきます。ちゃんと分類できてますね。
まとめ
ということで、pandasとかでなんか色々やらなくてもscikit-learnで決定木は一瞬で作れますよ、という解説でした。