LoginSignup
6
3

DBMSをC++で自作して、ついでに機械学習モデル用自作デバッガを動かしてみた

Last updated at Posted at 2023-07-12

概要

以下の「DBMSをGoで実装してみた」という記事が面白かったので、C++でDBMSを自作してみた。元記事で開発したDBMS (bogoDBと呼ばれている)をベースにしつつ、B木ではなくB+木を使ったインデクスや複数ページの管理、Foat型、複数列・テーブルの同時選択、ファイル実行、範囲検索のなど様々な機能を追加した。

ただ、単にDBMSを作っただけでは物足りない気がしたので、以前から気になっていたRainというSQLクエリをインターフェースとして機械学習モデルのデバック & 異常な学習データの除去を行うシステムもフルスクラッチで構築してみた。

実際の使用例を先に示すと、以下のデータセットbankruptの顧客の年齢と負債に基づいて、各顧客が倒産するかどうかを分類するモデルを作りたいとき、

>>Select * From bankrupt
id age debt y
1 40 0 0
2 21 10 0
3 22 10 0
4 32 30 0
5 44 50 1
6 30 100 1
7 63 310 1
8 53 420 1
9 39 530 1
10 49 1000 0

下記のQLライクなクエリを用いることでロジスティクス回帰モデルを訓練することができる。

>>Logreg lrmodel id y 100 1 From Select * From bankrupt

ただ、Selectの結果に注目すると、debtが大きくなるにつれて倒産リスクが高まっているように見えるにもかかわらず、10番目のレコードのラベルが0であるがために、debtに対応する係数がマイナスになってしまい、学習済みモデルの精度が低くなっている。

このような場合に、クライアントがComplaintというクエリを発行し、「debtが100以上である顧客のラベルは1である可能性が高い」ということを教えると、サーバー側で自動的に問題のあるレコードを自動除去し、モデルを学習しなおすことができる。

>>Complaint comp Shouldbe 1 Remove 1 Against Logreg lrmodel id y 100 1 From Select * From bankrupt Where debt Geq 100

ちなみに、このコードは以下のレポジトリで管理している。

自作DBMS

基本的なアーキテクチャはbogoDBの実装を参考にして各モジュールを実装しつつ、様々な機能を追加するためにデータ構造やエンジンのアルゴリズムを改変していった。また、httpとjson周りはヘッダーオンリーのライブラリを利用し、シリアライズにはbogoDBと同様にprotobufを用いた。

今回実装したDBMSは以下のサブモジュール群から成り立っており、ここではその概要のみを説明する。具体的なコードは前述のGitHubレポジトリで全て管理している。

src/
├── compiler
├── core
├── storage
└── utils

compiler

このサブモジュールでは、SQLクエリのパースを行うプロセルを実装している。bogoDBをベースにしつつ、小数点のサポートや複数テーブル・列の選択、後に説明する機械学習用クエリのサポートなどを行った。構文解析はトップダウンアプローチを用いており、例えばInsertは以下のように実装されている。

# src/compiler/parse.h
struct Parser {
  InsertStmt *insertTableStmt() {
    expect(INTO);
    Token *tblName = expect(STRING);
    if (!tblName) {
      return nullptr;
    }
    expect(VALUES);
    expect(LPAREN);

    std::vector<Expr *> exprs;
    while (true) {
      exprs.push_back(eq());
      if (consume(RPAREN)) {
        break;
      }
      expect(COMMA);
    }

    InsertStmt *insertNode = new InsertStmt();
    insertNode->TableName = tblName->str;
    insertNode->Values = exprs;

    return insertNode;
  }
};

core

このサブモジュールでは、クライアント側のフロント部分とサーバ側のクエリ実行エンジンを実装した。一般的なDBMSと同様に、インタラクティブなクエリの実行・結果表示および、sourceコマンドを用いたファイル実行をサポートしている。サーバ側の実行エンジンはExecutorと呼ばれる構造体で定義されており、例えばInsertの実行は以下のように実装されている。

// Executor
struct Executor {...};

inline ResultSet *Executor::insertTable(InsertQuery *q, Transaction *tran) {
      ...
  if (!inTransaction) {
    tran = beginTransaction();
  }
  storage::Tuple *t = NewTuple(tran->Txid(), q->Values);
  std::pair<int, int> tid = storage->InsertTuple(q->table->Name, t);
  storage->InsertIndex(q->Index, t->data(0).toi(), tid);
  if (!inTransaction) {
    commitTransaction(tran);
  }
      ...
}

inline ResultSet *Executor::executeMain(Query *q, Plan *p, Transaction *tran) {
      ...
  if (auto insertQuery = dynamic_cast<InsertQuery *>(q)) {
    return insertTable(insertQuery, tran);
  }
      ...
}

storage

このサブモジュールでは、タプル、ページ、バッファプール、ストレージなどデータ格納のための機能を実装している。bogoDBでは一枚のページしか実装しておらず最大で32個の要素しか一つのテーブルに格納できなかったため、これを改良して複数ページにまたがるデータ管理も実装した。

Executorは基本的にストレージのAPIを用いてデータの読み取り・書き出しを行っており、ストレージがバッファプールもしくはストレージから適切なタプルやインデクスを取得する。

# src/storage/storage.h
class Storage {
public:
  void insertPage(const std::string &tableName, PageStatus pstatus) {
        ...
  }

  TID InsertTuple(const std::string &tablename, storage::Tuple *t) {
        ...
  }

  BPlusTreeMap<int, TID> *CreateIndex(const std::string &indexName) {
        ...
  }

  void InsertIndex(const std::string &indexName, int item, TID &tid) {
        ...
  }

  BPlusTreeMap<int, TID> *ReadIndex(const std::string &indexName) {
        ...
  }

  storage::Tuple *ReadTuple(const std::string &tableName, TID tid) {
        ...
  }

  void Terminate() {
        ...
  }
};
class BufferPool {
public:
  int FrontPgid(const std::string &tableName) {
        ...
  }

  int NewPgid(const std::string &tableName) {
        ...
  }

  Page *readPage(const std::string &tableName, uint64_t tid) {
        ...
  }

  std::pair<bool, std::pair<int, int>> appendTuple(const std::string &tableName,
                                                   storage::Tuple *t) {
        ...
  }

  std::pair<bool, Page *> putPage(const std::string &tableName, uint64_t pgid,
                                  Page *p) {
        ...
};
#src/storage/disk.h
class DiskManager {
public:
  Page *fetchPage(const std::string &dirPath, const std::string &tableName,
                  uint64_t pgid) {
        ...
  }

  void persist(const std::string &dirName, const std::string &tableName,
               uint64_t pgid, const Page *page) {
        ...
  }

  BPlusTreeMap<int, TID> *readIndex(const std::string &indexName) {
        ...
  }

  void writeIndex(const std::string &dirPath, const std::string &indexName,
                  BPlusTreeMap<int, TID> *tree) {
        ...
  }
};

utils

このサブモジュールでは、その他のデータ構造などを定義しており、LRUキャッシュやB+木を実装した。B+木を使うことで、効率よく範囲検索ができるようになった。

Rain (機械学習モデル向けインタラクティブデバッガ)

RainはSQLライクなクエリを用いてユーザーがインタラクティブに機械学習モデルのデバッグ・再学習を行うことを可能にするシステムである。具体的なプロセスは、(1 機械学習モデルの望ましくない挙動を見つけ、(2 ユーザーがそれを指摘する苦情 (complaint)を伝え、(3 その苦情を微分可能な形に変換した上で、(4 苦情をより小さくするように学習データから一部のデータを取り除く、というものである。苦情の例としては、冒頭の例のような「負債が100以上のレコードはラベル1として識別される方が望ましい」などの形が想定されているが、より具体的に「負債が100以上のレコードのうちラベル0として識別される件数を最小化したい」や「負債が100以上のレコードに対するラベル0の予測確率の合計を最小化したい」など何らかの集計処理を行うことが考えられている。

そして、Rainでは苦情に対する各データの影響度、つまりどのレコードが苦情を引き起こしているのかを、以下の式で定義している。ただし$x$は訓練データ、$N$は訓練データの総数、$\theta$はモデルのパラメータ、$\ell$は損失関数である。

苦情に対する各データの影響度 = Q'H^{-1}E
Q' = \frac{\partial 苦情}{\partial \theta}, \quad H = \frac{\partial^{2} \sum^{N}_{i=1} \ell(\theta, x_i)}{\partial^{2} \theta} \quad E = \{\frac{\partial \ell(\theta, x_i)}{\partial \theta}\}^{N}_{i=1}

つまり、$H$は学習データ全体における損失に対するパラメータの勾配、$E$は一つ一つのレコードにおけるパラメータの勾配である。$Q'$の具体的な式は苦情の形式ごとに異なるが、「負債が100以上のレコードに対するラベル0の予測確率の合計を最小化したい」という例であれば、

Q' = \frac{\partial \sum_{i \in 負債が100以上} P(y_{i} = 0  | x_{i}, \theta)}{\partial \theta}

と定式化することができる。このようにして各データの影響度を求め、影響度が大きいものから順に$k$個のデータを訓練データセットから取り除き、モデルを再学習する。

今回は外部ライブラリをなるべく用いずに実装したかったので、ロジスティクス回帰を一から実装して、SQLライクなクエリで呼び出せるようにした。Logregは指定されたデータに対してロジスティクス回帰を実行するクエリであり、勾配降下法を用いた最適化を行っているためそのイテレーション回数と学習率を指定することができる。このクエリの結果としてはモデルのパラメータ、AUC、および個々のデータに対する予測値がどのテーブルに格納されているかが帰ってくる。

>>Logreg `model_name` `primary_key_name` `target_column_name` `number_of_iteration` `learning_rate` From Select `primary_key_name`, `feature_name` From `table_name`
Trained Parameters:
 (0) : 5.845724
 (1) : -1.187462
 (2) : -1.274891
AUC: 0.840000
Predictions on the training data are stored at `prediction_on_training_data_lrmodel`

Rainを用いたデバックも同様にSQLライクなクエリで実現でき、以下のようにWhere conditionで指定されたデータに対する予測値がtarget_classに近づくようにnumber_of_removedだけ学習データから問題のあるレコードを取り除いて再学習を行う。

>>Complaint `complaint_name` Shouldbe `target_class` Remove `number_of_removed_records` Against Logreg ... Where `condition`
Fixed Parameters:
 (0) : -4.765492
 (1) : 8.747224
 (2) : 0.744146
AUC: 1.000000
Prediction on the fixed training data is stored at `prediction_on_training_data_comp_lrmodel`

より具体的な実行方法やサンプルデータはGitHubに置いてある。

6
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
3