1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

本記事では、学習済み決定木モデルの推論をSQLのクエリに変換する、一番シンプルなpythonコードを解説します。

つまり、やりたいことはこういうことです。

image.png

決定木は比較的シンプルなロジックなので、推論処理をSQLのCASE式に変換して大規模に実行することが可能です1。なお、今回はscikit-learnの決定木の分類器を対象にしています。

コード

from typing import Optional, List

from sklearn.tree import BaseDecisionTree


def tree2sql(
    model: BaseDecisionTree,
    feature_names: Optional[list[str]] = None,
    col: str = "pred",
    table: str = "my_table",
    tab: int = 4,
) -> str:
    """
    学習済みの scikit-learn 決定木(分類モデル)を SQL のクエリに変換する。

    Args:
        model (BaseDecisionTree): 
            学習済みの `DecisionTreeClassifier` インスタンス。
            回帰木は非サポート。
        feature_names (Optional[List[str]], optional):
            モデルが予測に使用する入力列名のリスト。
            `None` の場合は `["f0", "f1", …]` が自動で設定される。
        col (str, optional): 生成する SQL の出力列の列名。
        table (str, optional): `FROM` 句で使用するテーブル名。
        tab (int, optional): SQL 文中のインデントに用いるスペース数。デフォルトは `4`。

    Returns:
        str: 木モデルの判定ロジックを再現する完全な SQL 文。
        形式は `SELECT … FROM …` 。
    """

    t = model.tree_
    names = feature_names or [f"f{i}" for i in range(t.n_features_in_)]
    classes = model.classes_

    def walk(node: int, depth: int) -> str:
        pad = " " * tab * depth
        if t.children_left[node] == t.children_right[node]:  # leaf
            label = classes[t.value[node][0].argmax()]
            return f"{pad}'{label}'"

        feat = names[t.feature[node]]
        thr = t.threshold[node]
        left_sql = walk(t.children_left[node], depth + 1)
        right_sql = walk(t.children_right[node], depth + 1)

        return (
            f"{pad}CASE WHEN {feat} <= {thr} THEN\n"
            f"{left_sql}\n"
            f"{pad}ELSE\n"
            f"{right_sql}\n"
            f"{pad}END"
        )

    body = walk(0, 1)  # indent the root once for aesthetics
    return f"SELECT\n{body} AS {col}\nFROM {table};"

実行例は下記の通りです。

tree2sql(model, X_columns)

# Output:
# SELECT
#     CASE WHEN petal_width_cm <= 1.550000011920929 THEN
#         CASE WHEN petal_width_cm <= 0.7000000029802322 THEN
#             '0'
#         ELSE
#             '1'
#         END
#     ELSE
#         '2'
#     END AS pred
# FROM iris_df;

なお、irisデータセットを用いた簡単な例はこちらのノートブックで試すことができます2

sklearnの木の内部構造

image.png

sklearnの学習済みの決定木は以下のようにデータを持っています。

  • 根も葉も含めたすべてのノードの数
    • model.tree_.node_count
    • 例: 5
  • 各ノードに対して、左側の子ノードのindex
    • model.tree_.children_left
    • 例: array([ 1, 2, -1, -1, -1], dtype=int64)
    • 上左図の数字が各ノードのindexで、例えばindex=1のノードは左下に②のノードを持っていることがわかります
  • 各ノードに対して、右側の子ノードのindexを持つ
    • model.tree_.children_right
    • 例: array([ 4, 3, -1, -1, -1], dtype=int64)
  • 各ノードに対して、分岐に使われる特徴量の番号
    • model.tree_.feature
    • 例: array([ 3, 3, -2, -2, -2], dtype=int64)
  • 各ノードに対して、分岐に使われる閾値の値
    • model.tree_.threshold
    • 例: array([ 1.55000001, 0.7 , -2. , -2. , -2. ])
  • 各ノードに対して、各クラスの確率(入ってるサンプルの割合に該当)
    • model.tree_.value
    • 例: 以下
array([[[0.23333333, 0.36666667, 0.4       ]],
       [[0.38888889, 0.61111111, 0.        ]],
       [[1.        , 0.        , 0.        ]],
       [[0.        , 1.        , 0.        ]],
       [[0.        , 0.        , 1.        ]]])

変換ロジック

Section1からコード抜粋して、一部コメントを入れる形で解説します。


def walk(node: int, depth: int) -> str:
    pad = " " * tab * depth

    # part1: 葉ノードの場合、該当する予測クラスを返す
    if t.children_left[node] == t.children_right[node]:  # leaf
        label = classes[t.value[node][0].argmax()]
        return f"{pad}'{label}'"

    feat = names[t.feature[node]]
    thr = t.threshold[node]
    # part2: 子ノードを再帰的に探索
    left_sql = walk(t.children_left[node], depth + 1)
    right_sql = walk(t.children_right[node], depth + 1)

    # part3: 分岐に使われている特徴量と閾値からCASE式の条件文を書き、
    # THEN以下は再帰的に取得したクエリを当てはめる
    return (
        f"{pad}CASE WHEN {feat} <= {thr} THEN\n"
        f"{left_sql}\n"
        f"{pad}ELSE\n"
        f"{right_sql}\n"
        f"{pad}END"
    )

# 根ノードから探索を開始する
body = walk(0, 1)  # indent the root once for aesthetics

# CASE式以外のクエリをくっつける
return f"SELECT\n{body} AS {col}\nFROM {table};"
  1. おそらく世の中にある各種ツールでも似たような形で実装されていると思います。なお、今回のpythonコードは https://github.com/hyperforest/tree2query/blob/main/tree_parser.py を参考にChatGPTに書かせたものをベースにしています。

  2. Github gistのURLはこちらです → https://gist.github.com/tanaka-jin/dff3c6056351141ecd2803e0d57e1de6

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?