3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

KLab EngineerAdvent Calendar 2023

Day 22

Python/C API で 素朴な二分探索木を書いてみる

Last updated at Posted at 2023-12-21

年末が近づくと Python/C API を無駄に使いたくなる衝動は今年も無事に(?)発生。なので素朴な二分探索木でも書いてみます。

from collections.abc import Collection, Iterator
from typing import TypeVar

T = TypeVar('T')


class BinarySearchTree(Collection[T]):
    ...

まずは Pure Python で書き下す

Poetry new で pyproject.toml

vim a.py で雑に始めて、コードが育ってきたら setup.py, setup.cfg, requirements.txt, MANIFEST.in を追加、でも動きはするのですが。 Poetry新モジュールのスケルトン を作ってもらうことにします。 requirements.txt からの卒業、 pyproject.toml への入門。とりあえず pytest と ipython と mypy をインストール。開発にしか使わないライブラリなので 依存関係を整理する 機能を用いて dev という名のグループにいれておきます。これでいらないときには poetry install --without dev とするなどの回避手段がとれるようになりました。開発開始。

poetry new fgshun-bst
cd fgshun-bst
poetry add pytest ipython mypy --group dev

でき上がったファイルたち。

fgshun-bst
|--README.md
|--fgshun_bst
|  |--__init__.py
|--poetry.lock
|--pyproject.toml
|--tests
|  |--__init__.py

開発開始

二分探索木を実装していきます。データ構造の教科書の序盤にでてきそうな、「最悪ケースだと線形リストと変わらない。なので……」という話の起点になりそうなものです。要素の追加ができて (add) 、要素のあるなしを知るための探索ができて (__contains__) 、要素数が分かって (__len__) 、要素を小さいほうから全列挙するイテレータを得ることができる (__iter__) 、そんなコレクション (collections.abc.Collection) を目指します。

全列挙は。二分探索木も木に違いないので根を起点に深さ優先探索すればよいでしょう。再帰関数、いや再帰ジェネレータで。

要素数は全列挙ができれば数えることができます。問われるたびに全列挙を行うのは効率が悪くはありますが。

探索は探そうとしている値と着目しているノードの値とを比べて、左の子もしくは右の子を新たな着目しているノードに変えていくループを書きます。

追加は同様に着目するノードを変えていき、ノードがないところに加えればよいでしょう。

ソート済み入力に対する性能が悪化する問題を抱えていたり、要素の削除機能が省略されていたりと半端ではありますが。とりあえずはこれで完成とします。

fgshun_bst/bst.py
class Node:
    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None


class BinarySearchTree:
    def __init__(self):
        self.root = None

    @classmethod
    def _iternodes(cls, node):
        if node is not None:
            yield from cls._iternodes(node.left)
            yield node
            yield from cls._iternodes(node.right)

    def __len__(self):
        nodes = self._iternodes(self.root)
        return sum(1 for node in nodes)

    def __iter__(self):
        nodes = self._iternodes(self.root)
        return (node.value for node in nodes)

    def __contains__(self, value):
        cur = self.root
        while cur:
            if value < cur.value:
                cur = cur.left
            elif cur.value < value:
                cur = cur.right
            else:
                return True
        return False

    def add(self, value):
        cur = self.root
        if cur is None:
            self.root = Node(value)
            return
        while True:
            if value < cur.value:
                if cur.left is None:
                    cur.left = Node(value)
                    return
                else:
                    cur = cur.left
            elif cur.value < value:
                if cur.right is None:
                    cur.right = Node(value)
                    return
                else:
                    cur = cur.right
            else:
                return

テストを書く

コードを書いたらユニットテストも追加しましょう。整数も文字も入れられて。 in 演算で入っている値には True 、入っていない値には False が得られて、全列挙ができてその順はソートされていることを確認。よし! (本当か?)

test/test_bst.py
import string

from fgshun_bst.bst import BinarySearchTree


def test_bst_int():
    values = [3, 1, 4, 1, 5, 9]
    py_set = frozenset(values)

    tree = BinarySearchTree()
    for c in values:
        tree.add(c)

    for i in range(10):
        assert (i in tree) is (i in py_set)


def test_bst_str():
    s = 'fgshun'
    py_set = frozenset(s)

    tree = BinarySearchTree()
    for c in s:
        tree.add(c)

    for c in string.ascii_letters:
        assert (c in tree) is (c in py_set)


def test_bst_iter():
    values = [3, 1, 4, 1, 5, 9]

    tree = BinarySearchTree()
    for c in values:
        tree.add(c)

    for a, b in zip(sorted(frozenset(values)), tree):
        assert a == b

typing で型ヒント

typing モジュールのドキュメントや PEP などを参考に型ヒントを付けていきます。のちに本体を C で書き直す予定なので、型ヒントは stub file (.pyi) に書いていきます。そう、元となるファイルに手を入れずに別ファイルに型ヒントを書いておくことができるのです。この手法は C Extension であるがために元 .py ファイルが存在しない場合にも対応ができます。そして PEP 561 にしたがって py.typed という名の空ファイルをいれておく ことで、型ヒント持ちのモジュールであることを示しておきます。

fgshun_bst/py.typed

fgshun_bst/__init__.py
from .bst import BinarySearchTree
fgshun_bst/__init__.pyi
from collections.abc import Collection, Iterator
from typing import TypeVar

T = TypeVar('T')


class BinarySearchTree(Collection[T]):
    def __init__(self) -> None:
        ...

    def __len__(self) -> int:
        ...

    def __iter__(self) -> Iterator[T]:
        ...

    def __contains__(self, value: T) -> bool:
        ...

    def add(self, value: T) -> None:
        ...

さぁ、 C 言語で書こう

Pure Python で書いていたコードを Python/C API を使う C 言語で書き直していきます。

C でビルドできるようにする

build 用スクリプトとして build.py を用意。 pyproject.toml へ設定。 setuptools を使うため requires に追加。これで poetry install でビルド、開発用インストールができるようになります。

pyproject.toml の一部を抜粋
[tool.poetry.build]
script = "build.py"
generate-setup-file = true

[build-system]
requires = ["poetry-core", "setuptools"]
build-backend = "poetry.core.masonry.api"
build.py
from setuptools import Extension
from setuptools.command.build_ext import build_ext

extensions = [
    Extension("fgshun_bst._bst", sources=["fgshun_bst/_bst.c"]),
]

def build(setup_kwargs):
    setup_kwargs.update({"ext_modules": extensions, "cmdclass": {"build_ext": build_ext}})

構造体の宣言

まずは二分探索木そのものを表す BinarySearchTreeObject と、その Node を宣言します。 BinarySearchTreeObject は Python 側からオブジェクトとして見える必要があるため、先頭に PyObject_HEAD マクロを入れて必要となるメンバをいれておきます。

typedef struct NODE {
    PyObject *value;
    struct NODE *left;
    struct NODE *right;
} Node;


typedef struct {
    PyObject_HEAD
    Node *root;
} BinarySearchTreeObject;

循環参照ガベージコレクション対応

二分探索木ことオブジェクトを抱えるオブジェクトを実装するためには、循環参照ガベージコレクション対応であることを伝える設定が、そして循環参照に備えるために抱えているオブジェクトをインタプリタに通知する仕組み、抱えているオブジェクトへの参照を破棄する仕組みが必要になります。

まずは __new__ にて。メモリを確保するにあたって PyObject_New ではなく PyObject_GC_New を使います。そして、初期化後には PyObject_GC_Track にてガベージコレクタの監視対象に加えます。

static PyObject *
bst_new(PyTypeObject *subtype, PyObject *args, PyObject *kwargs)
{
    static char *kwlist[] = {NULL};
    BinarySearchTreeObject *self;
    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "", kwlist)) { return NULL; }

    self = PyObject_GC_New(BinarySearchTreeObject, subtype);
    if (!self) { return NULL; }
    self->root = NULL;
    PyObject_GC_Track(self);

    return (PyObject *)self;
}

抱えているオブジェクトすべてをたぐれるようにしておかなければなりません、 Py_VISIT マクロ関数の対象に取ることで。

static int
bst_traverse_inner(Node *node, visitproc visit, void *arg)
{
    if (node == NULL) { return 0; }

    bst_traverse_inner(node->left, visit, arg);
    bst_traverse_inner(node->right, visit, arg);
    Py_VISIT(node->value);
    return 0;
}


static int
bst_traverse(BinarySearchTreeObject *self, visitproc visit, void *arg)
{
    return bst_traverse_inner(self->root, visit, arg);
}

抱えているオブジェクトへの参照を破棄できる必要があります、 Py_CLEAR マクロ関数の対象に取ることで。

static void
bst_clear_inner(Node *node)
{
    if (node == NULL) { return; }

    bst_clear_inner(node->left);
    bst_clear_inner(node->right);
    Py_CLEAR(node->value);
}


static int
bst_clear(BinarySearchTreeObject *self)
{
    bst_clear_inner(self->root);
    return 0;
}

自身を破棄する際にも注意がいります。まずはガベージコレクションの監視対象から外す必要があります。これには __new__ で用いた PyObject_GC_Track と対になる PyObject_GC_UnTrack を用います。次に抱えている他のオブジェクトへの参照を破棄する必要があります。これには先ほど作ったクリア用関数がそのまま使えます。最後に、PyObject_GC_New で得たメモリを破棄します。これも __new__ で用いた PyObject_GC_New と対になる PyObject_GC_Del を用います。Node 用に抱えたメモリの解法処理も必要となりますが、こちらの解説は追加処理の移植にて説明します。

static void
bst_dealloc_inner(Node *node)
{
    if (node == NULL) { return; }

    bst_dealloc_inner(node->left);
    node->left = NULL;
    bst_dealloc_inner(node->right);
    node->right = NULL;
    PyMem_Free(node);
}


static void
bst_dealloc(BinarySearchTreeObject *self)
{
    PyObject_GC_UnTrack(self);
    bst_clear(self);
    bst_dealloc_inner(self->root);
    self->root = NULL;
    PyObject_GC_Del(self);
}

最後に。この型が循環ガベージコレクションの監視対象であると設定します、クラス設定 PyType_Spec の flags に Py_TPFLAGS_HAVE_GC を設定することで。

static PyType_Spec bst_spec = {
    .name = "fgshun_bst.BinarySearchTree",
    .basicsize = sizeof(BinarySearchTreeObject),
    .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC,
};

全列挙処理の移植

全列挙処理を移植していきます。とはいえ、処理の中断と再会の実装は手間なので、再帰関数で二分探索木をたぐりつつ組み込み型 list に一度ため込んでしまうようにしました。できあがった list のイテレータを返します。 list の作成には PyList_New を、 list へのオブジェクトの追加は PyList_Append を使います。これらの関数たちを使う際にはエラーに終わった際の対応が必要となります。エラー発生時にどのような戻り値が得られるかは公式ドキュメントに記載があります。

static int
bst_iter_inner(Node *node, PyObject *temp)
{
    if (node == NULL) { return 0; }

    if (bst_iter_inner(node->left, temp) == -1) {
        return -1;
    }
    if (PyList_Append(temp, node->value) == -1) {
        return -1;
    }
    if (bst_iter_inner(node->right, temp) == -1) {
        return -1;
    }
    return 0;
}


static PyObject *
bst_iter(BinarySearchTreeObject *self)
{
    PyObject *temp;
    if (!(temp = PyList_New(0))) { return NULL; }

    if (bst_iter_inner(self->root, temp) == -1) {
        Py_XDECREF(temp);
        return NULL;
    }

    PyObject *iter = PyObject_GetIter(temp);
    Py_XDECREF(temp);
    return iter;
}

探索処理の移植

探索処理の移植のためには Python オブジェクトたちの比較処理がいります。探そうとしている値と着目しているノードの値との大小関係が分かる必要があるのですから。このためには PyObject_RichCompareBool を用います。これも失敗しうるためエラー対応が必要です。失敗時の戻り値は -1 です、 switch の default で拾います。

static int
bst_contains(BinarySearchTreeObject *self, PyObject *value)
{
    Node *cur = self->root;
    while (cur) {
        switch (PyObject_RichCompareBool(value, cur->value, Py_LT)) {
        case 1:
            cur = cur->left;
            break;
        case 0:
            switch (PyObject_RichCompareBool(cur->value, value, Py_LT)) {
            case 1:
                cur = cur->right;
                break;
            case 0:
                return 1;
            default:
                return -1;
            }
            break;
        default:
            return -1;
        }
    }
    return 0;
}

追加処理の移植

追加の処理は探索の処理と似たようなものになります。 Node 用のメモリの確保と破棄には C 標準の malloc, calloc, free の代わりに PyMem_Malooc, PyMem_Calloc, PyMem_Free を用いることができます。これは Python インタプリタが確保、監視している Python ヒープ領域からメモリを借り受けることができるものです。

/* 再掲、自身を破棄する処理、その再帰部分 */
static void
bst_dealloc_inner(Node *node)
{
    if (node == NULL) { return; }

    bst_dealloc_inner(node->left);
    node->left = NULL;
    bst_dealloc_inner(node->right);
    node->right = NULL;
    PyMem_Free(node);
}
/* 再掲、ここまで */

static PyObject *
bst_add(BinarySearchTreeObject *self, PyObject *value)
{
    Node *cur = self->root;
    if (cur == NULL) {
        cur = PyMem_Calloc(sizeof(Node), 1);
        if (cur == NULL) { return NULL; }
        Py_INCREF(value);
        cur->value = value;
        self->root = cur;
        Py_RETURN_NONE;
    }

    while (1) {
        switch (PyObject_RichCompareBool(value, cur->value, Py_LT)) {
/* 略 */

型ヒント対応、ジェネリックとするには

こうして C でつくったクラスはジェネリッククラスではないためカギカッコによる装飾ができない状態にあります。なので PEP 560 を参考に __class_getitem__ を実装しておきます。

static PyObject *
simple_class_getitem(PyObject *type, PyObject *item)
{
    Py_INCREF(type);
    return type;
}

モジュール、クラスの設定

作ってきた関数たちをしかるべき場所へ設定していきます。 PyMethodDef, PyType_Slot, PyType_Spec, PyModuleDef_Slot, PyModuleDef 。そして PyModuleDef_Init を呼ぶ関数を作成。

static PyMethodDef bst_methods[] = {
    {"add", (PyCFunction)bst_add, METH_O, NULL},
    {"__class_getitem__", simple_class_getitem, METH_O|METH_CLASS, NULL},
    {NULL, NULL, 0, NULL} /* Sentinel */
};


static PyType_Slot bst_slots[] = {
    {Py_tp_methods, bst_methods},
    {Py_tp_new, bst_new},
    {Py_tp_iter, (getiterfunc)bst_iter},
    {Py_tp_dealloc, bst_dealloc},
    {Py_tp_traverse, bst_traverse},
    {Py_tp_clear, bst_clear},
    {Py_sq_contains, bst_contains},
    {Py_sq_length, bst_len},
    {0, 0},
};


static PyType_Spec bst_spec = {
    .name = "fgshun_bst.BinarySearchTree",
    .basicsize = sizeof(BinarySearchTreeObject),
    .flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC,
    .slots = bst_slots,
};


static int moddef_exec(PyObject *module) {
    PyObject *bst_type;

    bst_type = PyType_FromSpec(&bst_spec);
    if (!bst_type) { goto error; }
    if (PyModule_AddObject(module, "BinarySearchTree", bst_type)) { goto error; }
    return 0;
error:
    Py_XDECREF(bst_type);
    Py_DECREF(module);
    return -1;
}


static PyModuleDef_Slot moddef_slots[] = {
    {Py_mod_exec, moddef_exec},
    {0, NULL}
};


static struct PyModuleDef moddef_module = {
    PyModuleDef_HEAD_INIT,
    .m_name = "_bst",
    .m_slots = moddef_slots,
};


PyMODINIT_FUNC PyInit__bst(void) {
    return PyModuleDef_Init(&moddef_module);
}

Pure Python 実装を C 実装に取り換える

bst.py のクラスを指していたところを _bst モジュールのを指すように変更しました。こうしても変わらずに動作することを確認しました。

fgshun_bst/__init__.py
from ._bst import BinarySearchTree

全コードへのリンク

おわりに

Python/C API で遊んでみるのは楽しいものです。 CPython の実装を、裏側を知ることにもつながります。年末にやってみることに加えてみてはいかがでしょうか。

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?