5
2

go-mysql-serverを使ってテンポラリなDBを立ち上げてテストを行う

Last updated at Posted at 2024-08-11

最近、新しいプロジェクトで、久しぶりにGoを書いています。

DBを使ったコードのテストをしたいという要求が自分の中にあってどうしようかなと考えたところ、go-mysql-serverを使ってみることにしました。
dockertestやtestcontainersとかもあるよう(参考: 紹介記事)ですが、こっちのが速いかなというただの印象で、使うことにしました。後々dockertestやtestcontainersも使えるようにしても良いかなとは思ってはいます。

なお、本当のMySQLではないので、制限はあります。
制限の簡単な訳と使ってみての追加(最後の2つ)を書いておきます。

  • DDLやDML(CREATE TABLE, INSERT, 等) はthreadsafeではない
  • transactionがない
  • INDEXがあっても速くはならない
  • 使えない関数がある(参考)
    • 関数がバグっているかもしれない(遭遇したのは、STR_TO_DATEが、%Y%m%dを解釈しない(issue))
  • 1つのサーバに2つのDBは作れない(必要なら別でもう一つ作る必要がある)

実際のテストコード

最初にどんなテストコードになるかを紹介します。あくまでサンプルコードなので、DBを使ってるところは超適当ですが。

package sample

import (
	"context"
	"database/sql"
	"log"
	"test/test_db"
	"testing"

	"github.com/stretchr/testify/assert"
)

func Hoge(db *sql.DB, ctx context.Context, hoge_id int) int {
	sql := "select count(*) as num from hoge_item join hoge on hoge.id = hoge_id where hoge.id = ?;"
	rows, err := db.QueryContext(ctx, sql, hoge_id)
	if err != nil {
		log.Fatal(err)
	}
	num := 0

	if rows.Next() {
		rows.Scan(&num)
	}
	return num
}

// 上の関数のテスト
func TestHoge(t *testing.T) {
	testSet := &test_db.TestDataSet{
		Tables: test_db.Tables("hoge", "hoge_item"),
		SQLs: test_db.SQLs(
			`
			INSERT INTO hoge(id, name)
			VALUES(1, 'hoge1'),
			       (2, 'hoge2')
			`,
			`
			INSERT INTO hoge_item
			      (id, hoge_id, item)
			VALUES(1, 1, 'hoge_item1'),
			      (2, 1, 'hoge_item2'),
			      (3, 1, 'hoge_item3'),
			      (4, 2, 'hoge_item4')
			 `,
		),
	}
	db, ctx, err := test_db.SetupTestDB(testSet)
	if err != nil {
		log.Fatal(err)
	}
	defer db.CleanupDB()

	num := Hoge(db.DB, ctx, 1)
	assert.Equal(t, 3, num)

	num = Hoge(db.DB, ctx, 2)
	assert.Equal(t, 1, num)

	num = Hoge(db.DB, ctx, 3)
	assert.Equal(t, 0, num)
}

ざっくり説明すると、test_db.TestDataSetのところで、テーブルとテーブルに対するデータを用意するための構造体を作っています。

  1. hoge, hoge_itemテーブルにたいして
  2. SQLsに渡した2つのSQLを実行する

という指示の入った構造体です。見たままですね(SQLからテーブルわかるやんという話はありますが、まぁ)。

それを、test_db.SetupTestDB に渡すことで、テンポラリなDBが返ってきます。
後は、それをテストしたい関数Hogeに渡して、結果をassertで比較しているだけです。

SQL部分はtest-data/sql/以下に.sqlな名前のファイルを作って、
以下のように複数まとめて書くことがもできます。--(改行 + "--" +改行)が区切りです。

INSERT INTO hoge(id, name)
VALUES(1, 'hoge1'),
       (2, 'hoge2');
--
INSERT INTO hoge_item
      (id, hoge_id, item)
VALUES(1, 1, 'hoge_item1'),
      (2, 1, 'hoge_item2'),
      (3, 1, 'hoge_item3'),
      (4, 1, 'hoge_item4');

以下、どのように実現しているかを書きます。

Dbmateの利用

が、その前に、Dbmateという migrationツールを利用していて、これも少し関係しています。

ORMapperにsqlcを使っているのですが、sqlcのドキュメントでmigrationツールとして紹介されていました(sqlcもなかなか面白いので、また記事を書くかもしれません)。

Dbmateは、migrationディレクトリにSQLを書いて、それが順番に実行されますが、最後に schemaのdumpファイルが生成されます(migration がどこまでされたかはテーブルで管理されているので、その情報も含まれています)。一から空のDBを作りたい時に便利ですね。

で、お察しかもしれませんが、最初にテストコード内でテーブル名をしていましたが、そのテーブル名をdumpファイルから探して、CREATE TABLE を抜き出して実行しています。

今回のテストは実際のプロジェクトのものではないので、以下のschemaを db/schema.sql に用意しました。

CREATE TABLE `hoge` (
  id int not null,
  name varchar(10),
  primary key (id)
);

CREATE TABLE `hoge_item` (
  id int not null,
  hoge_id int not null,
  item varchar(10),
  FOREIGN KEY (hoge_id) REFERENCES hoge(id),
  primary key (id)
);

なお、Dbmateですが、ググると、go get でのinstallを紹介している記事もありますが、それだと古いバージョンが入るので、公式のインストール方法を参照しましょう。

go-mysql-serverでテストする

ようやく本題です。といって、コードを紹介するだけですが...。コメントをいつもより多めに書いてます。
なお、以下の前提のもとのコードですので、ご注意ください(ハードコードされているので)。

  • CWD という環境変数にproject rootのパスがある
  • DBmateのschema dumpは、db/schema.sql にある
  • テストに使う sqlファイルは、test-data/sql/ 以下にある
package test_db

import (
	"bufio"
	"context"
	"database/sql"
	"fmt"
	"log"
	"net"
	"os"
	"strings"
	"time"

	sqle "github.com/dolthub/go-mysql-server"
	"github.com/dolthub/go-mysql-server/memory"
	"github.com/dolthub/go-mysql-server/server"
	_ "github.com/go-sql-driver/mysql"

	"github.com/pkg/errors"
)

// 作成するテーブル
type TargetTables []string

// 実行するSQL
type ExecSQLs []string

// 上記をセットにしている
type TestDataSet struct {
	Tables TargetTables
	SQLs   ExecSQLs
}

// テスト用の一時データベース
type TestDB struct {
	Cancel     context.CancelFunc
	SocketFile string
	Listener   net.Listener
	DB         *sql.DB
}

// 特に意味はないですね
func Tables(tables ...string) TargetTables {
	return TargetTables(tables)
}

// sql もしくは ファイル名を受け取って、ExecSQLsにして返す
func SQLs(sqls ...string) ExecSQLs {
	var sqlSentences []string
	cwd := os.Getenv("CWD")
	for _, sql := range sqls {
		l := len(sql)
        // .sql で終わっている場合は、SQLが入ったファイルとみなす
		if l > 4 && sql[l-4:] == ".sql" {
			dat, err := os.ReadFile(cwd + "/test-data/sql/" + sql)
			if err != nil {
				log.Fatal(err)
			}
			sqlString := string(dat)
			sqlStrings := strings.Split(sqlString, "\n--\n")
			sqlSentences = append(sqlSentences, sqlStrings...)
		} else {
			sqlSentences = append(sqlSentences, sql)
		}
	}
	return ExecSQLs(sqlSentences)
}

// テンポラリなDBのセットアップ
func setupTestDB() (*TestDB, error) {
	tdb := &TestDB{}
	// DBの名前。何でも良い
	dbName := "test_database"

	// インメモリーなSQLエンジンの作成
	db := memory.NewDatabase(dbName)
	db.BaseDatabase.EnablePrimaryKeyIndexes()
	pro := memory.NewDBProvider(db)
	engine := sqle.NewDefault(pro)

	// rootアカウントの追加
	engine.Analyzer.Catalog.MySQLDb.AddRootAccount()

	// Unixファイルソケットの作成
	tdb.SocketFile = fmt.Sprintf("/tmp/testdb_%d.sock", time.Now().UnixNano())

	// 作ったソケットでListenerを作る
	Listener, err := net.Listen("unix", tdb.SocketFile)
	if err != nil {
		return nil, errors.Wrap(err, "failed to start Unix socket listener")
	}

	// DBの設定でSocketを使うようにする
	config := server.Config{
		Socket:   tdb.SocketFile,
		Listener: Listener,
	}

	// サーバを作成する
	s, err := server.NewServer(config, engine, memory.NewSessionBuilder(pro), nil)
	if err != nil {
		return nil, err
	}

	// サーバを動かす
	go func() {
		if err := s.Start(); err != nil {
			log.Fatalf("failed to start server: %v", err)
		}
	}()

	// 必要なら起動するまでちょっと待つ(なくても大丈夫そうなのでコメント。必要にせよもっと短くて良い)
	// time.Sleep(1 * time.Second)

	// サーバに接続するDSNを作る
	dsn := fmt.Sprintf("root@unix(%s)/"+dbName, tdb.SocketFile)

	// サーバに接続
	tdb.DB, err = sql.Open("mysql", dsn)
	if err != nil {
		tdb.Cancel()
		Listener.Close()
		return nil, err
	}

	return tdb, nil
}

// 後処理をまとめて行う
func (tdb *TestDB) CleanupDB() {
	tdb.DB.Close()
	os.Remove(tdb.SocketFile)
}

// TestDataSetを渡して、テンポラリのDBとコンテキストを返す
func SetupTestDB(testSet *TestDataSet) (*TestDB, context.Context, error) {
	tdb, err := setupTestDB()
	if err != nil {
		log.Fatal(err)
	}
	ctx, err := setupDBSchema(tdb.DB, testSet.Tables, testSet.SQLs...)
	if err != nil {
		log.Fatal(err)
	}
	return tdb, ctx, err
}

// テーブルの作成とデータ準備(SQLの実行)
func setupDBSchema(db *sql.DB, targetTables TargetTables, insertSQL ...string) (context.Context, error) {
	pwd := os.Getenv("CWD")
	file := pwd + "/db/schema.sql"
	if targetTables != nil {
		dumpContent, err := getArrangedDumpFile(file, targetTables)
		if err != nil {
			return nil, err
		}
		if len(targetTables) != len(dumpContent) {
			msg := fmt.Sprintf("nubmer of tables and number of create SQL doesn't match: %#v %#v", targetTables, dumpContent)
			return nil, errors.New(msg)
		}

		if len(insertSQL) > 0 {
			dumpContent = append(dumpContent, insertSQL...)
		}
		db.Exec("SET GLOBAL FOREIGN_KEY_CHECKS=0;")
		for _, sql := range dumpContent {
			_, err = db.Query(sql)
			if err != nil {
				return nil, errors.New(err.Error() + ":" + sql)
			}
		}
		db.Exec("SET GLOBAL FOREIGN_KEY_CHECKS=1;")
	}

	return context.Background(), nil
}

// schema の dump から CREATE TABLE を抜き出す(By ChatGPT)
func getArrangedDumpFile(filePath string, targetTables []string) ([]string, error) {

	file, err := os.Open(filePath)
	if err != nil {
		return nil, errors.Wrapf(err, "Error opening file")
	}
	defer file.Close()

	scanner := bufio.NewScanner(file)
	var createStatement string = ""
	createStatementsMap := make(map[string]string, 0)
	inCreateStatement := false
	currentTargetTable := ""

	for scanner.Scan() {
		line := scanner.Text()
		if strings.HasPrefix(line, "LOCK TABLES") {
			break
		}
		if inCreateStatement {
			if strings.HasSuffix(line, ";") {
				createStatement += line + "\n"
				createStatementsMap[currentTargetTable] = createStatement
				createStatement = ""
				currentTargetTable = ""
				inCreateStatement = false
			} else {
				createStatement += line + "\n"
			}

		} else if !strings.HasPrefix(line, "/*") {
			if len(targetTables) > 0 {
				for _, table := range targetTables {
					if strings.HasPrefix(line, "CREATE TABLE `"+table+"`") {
						createStatement += line + "\n"
						inCreateStatement = true
						currentTargetTable = table
						break
					}
				}
			} else if strings.HasPrefix(line, "CREATE TABLE `") {
				createStatement += line + "\n"
				inCreateStatement = true
			}
		}
	}
	// fmt.Println(createStatement)
	if err := scanner.Err(); err != nil {
		fmt.Println("Error reading file:", err)
	}
	createStatements := make([]string, 0)
	if len(targetTables) > 0 {
		for _, table := range targetTables {
			if sql, ok := createStatementsMap[table]; ok {
				createStatements = append(createStatements, sql)
			} else {
				fmt.Println("no SQL for " + table)

			}
		}
	}

	return createStatements, nil
}

終わり

こんな感じでDBを使ったテストができるようになりました。まだ使い始めではありますが、わりかしうまくいっているように思います。
実際のプロジェクトでは、2つのDBを同時に扱う必要があるので、それも可能なようにはしていますが、今回のコードは単純化のために、DBを1つだけ使うコードにしました。

当初、go-msyql-serverのサイトを見ると、コードでtable作ったりしていて、めんどくさいなーと思っていたんですが、よく考えたらサーバなんだから、直接SQL実行すりゃいいやんと思い、こうなりました。

テストデータをつくるSQLをファイルにしておけば、他のテストでも再利用できますし、実際のDBにテストデータ入れる時にも役に立つかと思いますので、そういうのも含めて良いのではないかなと思っています。

5
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
2