概要
以下の「DBMSをGoで実装してみた」という記事が面白かったので、C++でDBMSを自作してみた。元記事で開発したDBMS (bogoDBと呼ばれている)をベースにしつつ、B木ではなくB+木を使ったインデクスや複数ページの管理、Foat型、複数列・テーブルの同時選択、ファイル実行、範囲検索のなど様々な機能を追加した。
ただ、単にDBMSを作っただけでは物足りない気がしたので、以前から気になっていたRainというSQLクエリをインターフェースとして機械学習モデルのデバック & 異常な学習データの除去を行うシステムもフルスクラッチで構築してみた。
実際の使用例を先に示すと、以下のデータセットbankrupt
の顧客の年齢と負債に基づいて、各顧客が倒産するかどうかを分類するモデルを作りたいとき、
>>Select * From bankrupt
id age debt y
1 40 0 0
2 21 10 0
3 22 10 0
4 32 30 0
5 44 50 1
6 30 100 1
7 63 310 1
8 53 420 1
9 39 530 1
10 49 1000 0
下記のQLライクなクエリを用いることでロジスティクス回帰モデルを訓練することができる。
>>Logreg lrmodel id y 100 1 From Select * From bankrupt
ただ、Select
の結果に注目すると、debt
が大きくなるにつれて倒産リスクが高まっているように見えるにもかかわらず、10番目のレコードのラベルが0であるがために、debt
に対応する係数がマイナスになってしまい、学習済みモデルの精度が低くなっている。
このような場合に、クライアントがComplaint
というクエリを発行し、「debt
が100以上である顧客のラベルは1である可能性が高い」ということを教えると、サーバー側で自動的に問題のあるレコードを自動除去し、モデルを学習しなおすことができる。
>>Complaint comp Shouldbe 1 Remove 1 Against Logreg lrmodel id y 100 1 From Select * From bankrupt Where debt Geq 100
ちなみに、このコードは以下のレポジトリで管理している。
自作DBMS
基本的なアーキテクチャはbogoDBの実装を参考にして各モジュールを実装しつつ、様々な機能を追加するためにデータ構造やエンジンのアルゴリズムを改変していった。また、httpとjson周りはヘッダーオンリーのライブラリを利用し、シリアライズにはbogoDBと同様にprotobufを用いた。
今回実装したDBMSは以下のサブモジュール群から成り立っており、ここではその概要のみを説明する。具体的なコードは前述のGitHubレポジトリで全て管理している。
src/
├── compiler
├── core
├── storage
└── utils
compiler
このサブモジュールでは、SQLクエリのパースを行うプロセルを実装している。bogoDBをベースにしつつ、小数点のサポートや複数テーブル・列の選択、後に説明する機械学習用クエリのサポートなどを行った。構文解析はトップダウンアプローチを用いており、例えばInsert
は以下のように実装されている。
# src/compiler/parse.h
struct Parser {
InsertStmt *insertTableStmt() {
expect(INTO);
Token *tblName = expect(STRING);
if (!tblName) {
return nullptr;
}
expect(VALUES);
expect(LPAREN);
std::vector<Expr *> exprs;
while (true) {
exprs.push_back(eq());
if (consume(RPAREN)) {
break;
}
expect(COMMA);
}
InsertStmt *insertNode = new InsertStmt();
insertNode->TableName = tblName->str;
insertNode->Values = exprs;
return insertNode;
}
};
core
このサブモジュールでは、クライアント側のフロント部分とサーバ側のクエリ実行エンジンを実装した。一般的なDBMSと同様に、インタラクティブなクエリの実行・結果表示および、source
コマンドを用いたファイル実行をサポートしている。サーバ側の実行エンジンはExecutor
と呼ばれる構造体で定義されており、例えばInsert
の実行は以下のように実装されている。
// Executor
struct Executor {...};
inline ResultSet *Executor::insertTable(InsertQuery *q, Transaction *tran) {
...
if (!inTransaction) {
tran = beginTransaction();
}
storage::Tuple *t = NewTuple(tran->Txid(), q->Values);
std::pair<int, int> tid = storage->InsertTuple(q->table->Name, t);
storage->InsertIndex(q->Index, t->data(0).toi(), tid);
if (!inTransaction) {
commitTransaction(tran);
}
...
}
inline ResultSet *Executor::executeMain(Query *q, Plan *p, Transaction *tran) {
...
if (auto insertQuery = dynamic_cast<InsertQuery *>(q)) {
return insertTable(insertQuery, tran);
}
...
}
storage
このサブモジュールでは、タプル、ページ、バッファプール、ストレージなどデータ格納のための機能を実装している。bogoDBでは一枚のページしか実装しておらず最大で32個の要素しか一つのテーブルに格納できなかったため、これを改良して複数ページにまたがるデータ管理も実装した。
Executorは基本的にストレージのAPIを用いてデータの読み取り・書き出しを行っており、ストレージがバッファプールもしくはストレージから適切なタプルやインデクスを取得する。
# src/storage/storage.h
class Storage {
public:
void insertPage(const std::string &tableName, PageStatus pstatus) {
...
}
TID InsertTuple(const std::string &tablename, storage::Tuple *t) {
...
}
BPlusTreeMap<int, TID> *CreateIndex(const std::string &indexName) {
...
}
void InsertIndex(const std::string &indexName, int item, TID &tid) {
...
}
BPlusTreeMap<int, TID> *ReadIndex(const std::string &indexName) {
...
}
storage::Tuple *ReadTuple(const std::string &tableName, TID tid) {
...
}
void Terminate() {
...
}
};
class BufferPool {
public:
int FrontPgid(const std::string &tableName) {
...
}
int NewPgid(const std::string &tableName) {
...
}
Page *readPage(const std::string &tableName, uint64_t tid) {
...
}
std::pair<bool, std::pair<int, int>> appendTuple(const std::string &tableName,
storage::Tuple *t) {
...
}
std::pair<bool, Page *> putPage(const std::string &tableName, uint64_t pgid,
Page *p) {
...
};
#src/storage/disk.h
class DiskManager {
public:
Page *fetchPage(const std::string &dirPath, const std::string &tableName,
uint64_t pgid) {
...
}
void persist(const std::string &dirName, const std::string &tableName,
uint64_t pgid, const Page *page) {
...
}
BPlusTreeMap<int, TID> *readIndex(const std::string &indexName) {
...
}
void writeIndex(const std::string &dirPath, const std::string &indexName,
BPlusTreeMap<int, TID> *tree) {
...
}
};
utils
このサブモジュールでは、その他のデータ構造などを定義しており、LRUキャッシュやB+木を実装した。B+木を使うことで、効率よく範囲検索ができるようになった。
Rain (機械学習モデル向けインタラクティブデバッガ)
RainはSQLライクなクエリを用いてユーザーがインタラクティブに機械学習モデルのデバッグ・再学習を行うことを可能にするシステムである。具体的なプロセスは、(1 機械学習モデルの望ましくない挙動を見つけ、(2 ユーザーがそれを指摘する苦情 (complaint)を伝え、(3 その苦情を微分可能な形に変換した上で、(4 苦情をより小さくするように学習データから一部のデータを取り除く、というものである。苦情の例としては、冒頭の例のような「負債が100以上のレコードはラベル1として識別される方が望ましい」などの形が想定されているが、より具体的に「負債が100以上のレコードのうちラベル0として識別される件数を最小化したい」や「負債が100以上のレコードに対するラベル0の予測確率の合計を最小化したい」など何らかの集計処理を行うことが考えられている。
そして、Rainでは苦情に対する各データの影響度、つまりどのレコードが苦情を引き起こしているのかを、以下の式で定義している。ただし$x$は訓練データ、$N$は訓練データの総数、$\theta$はモデルのパラメータ、$\ell$は損失関数である。
苦情に対する各データの影響度 = Q'H^{-1}E
Q' = \frac{\partial 苦情}{\partial \theta}, \quad H = \frac{\partial^{2} \sum^{N}_{i=1} \ell(\theta, x_i)}{\partial^{2} \theta} \quad E = \{\frac{\partial \ell(\theta, x_i)}{\partial \theta}\}^{N}_{i=1}
つまり、$H$は学習データ全体における損失に対するパラメータの勾配、$E$は一つ一つのレコードにおけるパラメータの勾配である。$Q'$の具体的な式は苦情の形式ごとに異なるが、「負債が100以上のレコードに対するラベル0の予測確率の合計を最小化したい」という例であれば、
Q' = \frac{\partial \sum_{i \in 負債が100以上} P(y_{i} = 0 | x_{i}, \theta)}{\partial \theta}
と定式化することができる。このようにして各データの影響度を求め、影響度が大きいものから順に$k$個のデータを訓練データセットから取り除き、モデルを再学習する。
今回は外部ライブラリをなるべく用いずに実装したかったので、ロジスティクス回帰を一から実装して、SQLライクなクエリで呼び出せるようにした。Logreg
は指定されたデータに対してロジスティクス回帰を実行するクエリであり、勾配降下法を用いた最適化を行っているためそのイテレーション回数と学習率を指定することができる。このクエリの結果としてはモデルのパラメータ、AUC、および個々のデータに対する予測値がどのテーブルに格納されているかが帰ってくる。
>>Logreg `model_name` `primary_key_name` `target_column_name` `number_of_iteration` `learning_rate` From Select `primary_key_name`, `feature_name` From `table_name`
Trained Parameters:
(0) : 5.845724
(1) : -1.187462
(2) : -1.274891
AUC: 0.840000
Predictions on the training data are stored at `prediction_on_training_data_lrmodel`
Rainを用いたデバックも同様にSQLライクなクエリで実現でき、以下のようにWhere condition
で指定されたデータに対する予測値がtarget_class
に近づくようにnumber_of_removed
だけ学習データから問題のあるレコードを取り除いて再学習を行う。
>>Complaint `complaint_name` Shouldbe `target_class` Remove `number_of_removed_records` Against Logreg ... Where `condition`
Fixed Parameters:
(0) : -4.765492
(1) : 8.747224
(2) : 0.744146
AUC: 1.000000
Prediction on the fixed training data is stored at `prediction_on_training_data_comp_lrmodel`
より具体的な実行方法やサンプルデータはGitHubに置いてある。