本記事では、学習済み決定木モデルの推論をSQLのクエリに変換する、一番シンプルなpythonコードを解説します。
つまり、やりたいことはこういうことです。
決定木は比較的シンプルなロジックなので、推論処理をSQLのCASE式に変換して大規模に実行することが可能です1。なお、今回はscikit-learnの決定木の分類器を対象にしています。
コード
from typing import Optional, List
from sklearn.tree import BaseDecisionTree
def tree2sql(
model: BaseDecisionTree,
feature_names: Optional[list[str]] = None,
col: str = "pred",
table: str = "my_table",
tab: int = 4,
) -> str:
"""
学習済みの scikit-learn 決定木(分類モデル)を SQL のクエリに変換する。
Args:
model (BaseDecisionTree):
学習済みの `DecisionTreeClassifier` インスタンス。
回帰木は非サポート。
feature_names (Optional[List[str]], optional):
モデルが予測に使用する入力列名のリスト。
`None` の場合は `["f0", "f1", …]` が自動で設定される。
col (str, optional): 生成する SQL の出力列の列名。
table (str, optional): `FROM` 句で使用するテーブル名。
tab (int, optional): SQL 文中のインデントに用いるスペース数。デフォルトは `4`。
Returns:
str: 木モデルの判定ロジックを再現する完全な SQL 文。
形式は `SELECT … FROM …` 。
"""
t = model.tree_
names = feature_names or [f"f{i}" for i in range(t.n_features_in_)]
classes = model.classes_
def walk(node: int, depth: int) -> str:
pad = " " * tab * depth
if t.children_left[node] == t.children_right[node]: # leaf
label = classes[t.value[node][0].argmax()]
return f"{pad}'{label}'"
feat = names[t.feature[node]]
thr = t.threshold[node]
left_sql = walk(t.children_left[node], depth + 1)
right_sql = walk(t.children_right[node], depth + 1)
return (
f"{pad}CASE WHEN {feat} <= {thr} THEN\n"
f"{left_sql}\n"
f"{pad}ELSE\n"
f"{right_sql}\n"
f"{pad}END"
)
body = walk(0, 1) # indent the root once for aesthetics
return f"SELECT\n{body} AS {col}\nFROM {table};"
実行例は下記の通りです。
tree2sql(model, X_columns)
# Output:
# SELECT
# CASE WHEN petal_width_cm <= 1.550000011920929 THEN
# CASE WHEN petal_width_cm <= 0.7000000029802322 THEN
# '0'
# ELSE
# '1'
# END
# ELSE
# '2'
# END AS pred
# FROM iris_df;
なお、irisデータセットを用いた簡単な例はこちらのノートブックで試すことができます2。
sklearnの木の内部構造
sklearnの学習済みの決定木は以下のようにデータを持っています。
- 根も葉も含めたすべてのノードの数
model.tree_.node_count
- 例:
5
- 各ノードに対して、左側の子ノードのindex
model.tree_.children_left
- 例:
array([ 1, 2, -1, -1, -1], dtype=int64)
- 上左図の数字が各ノードのindexで、例えばindex=1のノードは左下に②のノードを持っていることがわかります
- 各ノードに対して、右側の子ノードのindexを持つ
model.tree_.children_right
- 例:
array([ 4, 3, -1, -1, -1], dtype=int64)
- 各ノードに対して、分岐に使われる特徴量の番号
model.tree_.feature
- 例:
array([ 3, 3, -2, -2, -2], dtype=int64)
- 各ノードに対して、分岐に使われる閾値の値
model.tree_.threshold
- 例:
array([ 1.55000001, 0.7 , -2. , -2. , -2. ])
- 各ノードに対して、各クラスの確率(入ってるサンプルの割合に該当)
model.tree_.value
- 例: 以下
array([[[0.23333333, 0.36666667, 0.4 ]],
[[0.38888889, 0.61111111, 0. ]],
[[1. , 0. , 0. ]],
[[0. , 1. , 0. ]],
[[0. , 0. , 1. ]]])
変換ロジック
Section1からコード抜粋して、一部コメントを入れる形で解説します。
def walk(node: int, depth: int) -> str:
pad = " " * tab * depth
# part1: 葉ノードの場合、該当する予測クラスを返す
if t.children_left[node] == t.children_right[node]: # leaf
label = classes[t.value[node][0].argmax()]
return f"{pad}'{label}'"
feat = names[t.feature[node]]
thr = t.threshold[node]
# part2: 子ノードを再帰的に探索
left_sql = walk(t.children_left[node], depth + 1)
right_sql = walk(t.children_right[node], depth + 1)
# part3: 分岐に使われている特徴量と閾値からCASE式の条件文を書き、
# THEN以下は再帰的に取得したクエリを当てはめる
return (
f"{pad}CASE WHEN {feat} <= {thr} THEN\n"
f"{left_sql}\n"
f"{pad}ELSE\n"
f"{right_sql}\n"
f"{pad}END"
)
# 根ノードから探索を開始する
body = walk(0, 1) # indent the root once for aesthetics
# CASE式以外のクエリをくっつける
return f"SELECT\n{body} AS {col}\nFROM {table};"
-
おそらく世の中にある各種ツールでも似たような形で実装されていると思います。なお、今回のpythonコードは https://github.com/hyperforest/tree2query/blob/main/tree_parser.py を参考にChatGPTに書かせたものをベースにしています。 ↩
-
Github gistのURLはこちらです → https://gist.github.com/tanaka-jin/dff3c6056351141ecd2803e0d57e1de6 ↩