はじめに
みなさん、こんにちは
Amebaの広告プロダクトチームでバックエンドエンジニアをやっている永井です。
この記事は、メディア事業部の広告横軸組織PTAのアドベントカレンダー1日目の記事となります。
動機
仕事でDBにアクセスするアプリケーションをScalaからGoに書き換えるにあたって、テーブル定義を構造体からコツコツ作るのめんどくさいなと思い、自動でいい感じに生成してくれるツールないかなと検討した結果、genが便利そうだったので利用することにしました。
利用する分には、便利なのですが実際どのように自動生成してるのか気になったので調べてみました。
genとは
RDBのテーブル定義からGoの構造体を自動生成してくれるツールです。genで生成した構造体はgormで利用可能です。
挙動追跡
下記のコードを実行するだけでテーブル定義に沿ったGoの構造体をサクッと生成してくれます。
package main
import "gorm.io/gen"
func main() {
g := gen.NewGenerator(gen.Config{
OutPath: "../query",
Mode: gen.WithDefaultQuery, // generate mode
})
// gormでDBに接続します。接続先の情報は適当なものです。
db, _ := gorm.Open(mysql.Open("root:@(127.0.0.1:3306)/demo?charset=utf8mb4&parseTime=True&loc=Local"))
g.UseDB(db) // reuse your gorm db
g.GenerateAllTable()
g.Execute()
}
流れとしては、下記のようになっています。今回は挙動を追跡するにあたって、UseDB
以降のソースコードを読み解いていきます。
- Generator生成
NewGenerator
- 利用するDB指定
UseDB
- 全テーブルを自動生成対象として指定
GenerateAllTable
- 実行
Execute
UseDB
func (g *Generator) UseDB(db *gorm.DB) {
if db != nil {
g.db = db
}
}
引数として受け取ったdbをdbフィールドに代入していますね。dbフィールドを経由してRDBにアクセスするんだろうなということが伺えます。
ちなみにGenerator
はこのような定義になっています。Config
のフィールドとして定義されているようですね。
type Generator struct {
Config
Data map[string]*genInfo //gen query data
models map[string]*generate.QueryStructMeta //gen model data
}
Config
の定義はこのようになっており、db
がありました。色々と設定項目があることが伺えます。今回は細かく触れません。
type Config struct {
db *gorm.DB // db connection
OutPath string // query code path
OutFile string // query code file name, default: gen.go
ModelPkgPath string // generated model code's package name
WithUnitTest bool // generate unit test for query code
// generate model global configuration
FieldNullable bool // generate pointer when field is nullable
FieldCoverable bool // generate pointer when field has default value, to fix problem zero value cannot be assign: https://gorm.io/docs/create.html#Default-Values
FieldSignable bool // detect integer field's unsigned type, adjust generated data type
FieldWithIndexTag bool // generate with gorm index tag
FieldWithTypeTag bool // generate with gorm column type tag
Mode GenerateMode // generate mode
queryPkgName string // generated query code's package name
modelPkgPath string // model pkg path in target project
dbNameOpts []model.SchemaNameOpt
importPkgPaths []string
// name strategy for syncing table from db
tableNameNS func(tableName string) (targetTableName string)
modelNameNS func(tableName string) (modelName string)
fileNameNS func(tableName string) (fileName string)
dataTypeMap map[string]func(detailType string) (dataType string)
fieldJSONTagNS func(columnName string) (tagContent string)
fieldNewTagNS func(columnName string) (tagContent string)
modelOpts []ModelOpt
}
GenerateAllTable
func (g *Generator) GenerateAllTable(opts ...ModelOpt) (tableModels []interface{}) {
tableList, err := g.db.Migrator().GetTables()
if err != nil {
panic(fmt.Errorf("get all tables fail: %w", err))
}
g.info(fmt.Sprintf("find %d table from db: %s", len(tableList), tableList))
tableModels = make([]interface{}, len(tableList))
for i, tableName := range tableList {
tableModels[i] = g.GenerateModel(tableName, opts...)
}
return tableModels
}
この関数は下記のような流れになっています。それぞれ詳しく見ていきます。
- テーブル一覧取得
g.db.Migrator().GetTables()
- 各テーブルに対して
GenerateModel
を実行する
GetTables
各種DriverごとにMigrator
が存在するようです。今回はmysqlのGetTables
を見ていきます。直接SQLが記述されていますね。information_schema.tables
というテーブルからテーブル一覧を取得していることがわかります。
func (m Migrator) GetTables() (tableList []string, err error) {
err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()).
Scan(&tableList).Error
return
}
GenerateModel
内部でGenerateModelAs
を呼び出していますね。
func (g *Generator) GenerateModel(tableName string, opts ...ModelOpt) *generate.QueryStructMeta {
return g.GenerateModelAs(tableName, g.db.Config.NamingStrategy.SchemaName(tableName), opts...)
}
GenerateModelAs
中身を見ていくと、テーブル名からメタ情報を取得し、メタ情報をgenerator
のmodels
マップに代入していますね。肝はGetQueryStructMeta
関数のようです。
func (g *Generator) GenerateModelAs(tableName string, modelName string, opts ...ModelOpt) *generate.QueryStructMeta {
meta, err := generate.GetQueryStructMeta(g.db, g.genModelConfig(tableName, modelName, opts))
if err != nil {
g.db.Logger.Error(context.Background(), "generate struct from table fail: %s", err)
panic("generate struct fail")
}
if meta == nil {
g.info(fmt.Sprintf("ignore table <%s>", tableName))
return nil
}
g.models[meta.ModelStructName] = meta
g.info(fmt.Sprintf("got %d columns from table <%s>", len(meta.Fields), meta.TableName))
return meta
}
GetQueryStructMeta
getTableColumns
でカラムの一覧を取得し、QueryStructMeta
を返していますね。メタ情報がたくさんありますね。getTableColumns
が大事そうです。
func GetQueryStructMeta(db *gorm.DB, conf *model.Config) (*QueryStructMeta, error) {
if _, ok := db.Config.Dialector.(tests.DummyDialector); ok {
return nil, fmt.Errorf("UseDB() is necessary to generate model struct [%s] from database table [%s]", conf.ModelName, conf.TableName)
}
conf = conf.Preprocess()
tableName, structName, fileName := conf.GetNames()
if tableName == "" {
return nil, nil
}
if err := checkStructName(structName); err != nil {
return nil, fmt.Errorf("model name %q is invalid: %w", structName, err)
}
columns, err := getTableColumns(db, conf.GetSchemaName(db), tableName, conf.FieldWithIndexTag)
if err != nil {
return nil, err
}
return (&QueryStructMeta{
db: db,
Source: model.Table,
Generated: true,
FileName: fileName,
TableName: tableName,
ModelStructName: structName,
QueryStructName: uncaptialize(structName),
S: strings.ToLower(structName[0:1]),
StructInfo: parser.Param{Type: structName, Package: conf.ModelPkg},
ImportPkgPaths: conf.ImportPkgPaths,
Fields: getFields(db, conf, columns),
}).addMethodFromAddMethodOpt(conf.GetModelMethods()...), nil
}
getTableColumns
GetTableColumns
でカラムの一覧を取得し、GetTableIndex
でインデックスの一覧も取得していますね。その後、カラムの情報にindexの情報を追加しているようです。GetTableColumns
の中身が気になります。
func getTableColumns(db *gorm.DB, schemaName string, tableName string, indexTag bool) (result []*model.Column, err error) {
if db == nil {
return nil, errors.New("gorm db is nil")
}
mt := getTableInfo(db)
result, err = mt.GetTableColumns(schemaName, tableName)
if err != nil {
return nil, err
}
if !indexTag || len(result) == 0 {
return result, nil
}
index, err := mt.GetTableIndex(schemaName, tableName)
if err != nil { //ignore find index err
db.Logger.Warn(context.Background(), "GetTableIndex for %s,err=%s", tableName, err.Error())
return result, nil
}
if len(index) == 0 {
return result, nil
}
im := model.GroupByColumn(index)
for _, c := range result {
c.Indexes = im[c.Name()]
}
return result, nil
}
GetTableColumns
Migrator
が出てきましたね。もうすぐSQLが拝めそうです。ColumnTypes
でカラムの一覧を取得しているようですね。
func (t *tableInfo) GetTableColumns(schemaName string, tableName string) (result []*model.Column, err error) {
types, err := t.Migrator().ColumnTypes(tableName)
if err != nil {
return nil, err
}
for _, column := range types {
result = append(result, &model.Column{ColumnType: column, TableName: tableName, UseScanType: t.Dialector.Name() != "mysql" && t.Dialector.Name() != "sqlite"})
}
return result, nil
}
ColumnTypes
読みたくなくなるコード量になってきましたが、ついにSQLに辿り着きました。大事なのはどのテーブルからどのような情報を取得しているかです。information_schema.columns
からtable_schema
とtable_name
を指定しているようですね。そこから、カラムの名前column_name
やデフォルト値column_default
、データタイプdata_type
などなどを取得しているようです。その後、gormの構造体して組み立てて返却していますね。
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
columnTypes := make([]gorm.ColumnType, 0)
err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
var (
currentDatabase, table = m.CurrentSchema(stmt, stmt.Table)
columnTypeSQL = "SELECT column_name, column_default, is_nullable = 'YES', data_type, character_maximum_length, column_type, column_key, extra, column_comment, numeric_precision, numeric_scale "
rows, err = m.DB.Session(&gorm.Session{}).Table(table).Limit(1).Rows()
)
if err != nil {
return err
}
rawColumnTypes, err := rows.ColumnTypes()
if err := rows.Close(); err != nil {
return err
}
if !m.DisableDatetimePrecision {
columnTypeSQL += ", datetime_precision "
}
columnTypeSQL += "FROM information_schema.columns WHERE table_schema = ? AND table_name = ? ORDER BY ORDINAL_POSITION"
columns, rowErr := m.DB.Raw(columnTypeSQL, currentDatabase, table).Rows()
if rowErr != nil {
return rowErr
}
defer columns.Close()
for columns.Next() {
var (
column migrator.ColumnType
datetimePrecision sql.NullInt64
extraValue sql.NullString
columnKey sql.NullString
values = []interface{}{
&column.NameValue, &column.DefaultValueValue, &column.NullableValue, &column.DataTypeValue, &column.LengthValue, &column.ColumnTypeValue, &columnKey, &extraValue, &column.CommentValue, &column.DecimalSizeValue, &column.ScaleValue,
}
)
if !m.DisableDatetimePrecision {
values = append(values, &datetimePrecision)
}
if scanErr := columns.Scan(values...); scanErr != nil {
return scanErr
}
column.PrimaryKeyValue = sql.NullBool{Bool: false, Valid: true}
column.UniqueValue = sql.NullBool{Bool: false, Valid: true}
switch columnKey.String {
case "PRI":
column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
case "UNI":
column.UniqueValue = sql.NullBool{Bool: true, Valid: true}
}
if strings.Contains(extraValue.String, "auto_increment") {
column.AutoIncrementValue = sql.NullBool{Bool: true, Valid: true}
}
column.DefaultValueValue.String = strings.Trim(column.DefaultValueValue.String, "'")
if m.Dialector.DontSupportNullAsDefaultValue {
// rewrite mariadb default value like other version
if column.DefaultValueValue.Valid && column.DefaultValueValue.String == "NULL" {
column.DefaultValueValue.Valid = false
column.DefaultValueValue.String = ""
}
}
if datetimePrecision.Valid {
column.DecimalSizeValue = datetimePrecision
}
for _, c := range rawColumnTypes {
if c.Name() == column.NameValue.String {
column.SQLColumnType = c
break
}
}
columnTypes = append(columnTypes, column)
}
return nil
})
return columnTypes, err
}
GenerateModelまとめ
長くなってきたので一旦まとめましょう。いまはGenerateModel
を深掘り、カラムの情報をinformation_schema.columns
から取得し、メタ情報として組み立ててGenerator
のmodels
マップに代入していることがわかりました。
GenerateAllTableまとめ
さらに上の階層のGenerateAllTable
についてもここでまとめましょう。GetTables
でテーブル一覧を取得し、各テーブルについてGenerateModel
でカラム情報などのメタ情報を取得し、Generator
のmodels
マップに代入していることがわかりました。
Execute
さて、最後にExecute
です。generateModelFile
を実行して構造体を生成し、generateQueryFile
を実行してクエリ用のコードを生成しているようです。今回は構造体の生成に着目して、generateModelFile
を見ていきます。
func (g *Generator) Execute() {
g.info("Start generating code.")
if err := g.generateModelFile(); err != nil {
g.db.Logger.Error(context.Background(), "generate model struct fail: %s", err)
panic("generate model struct fail")
}
if err := g.generateQueryFile(); err != nil {
g.db.Logger.Error(context.Background(), "generate query code fail: %s", err)
panic("generate query code fail")
}
g.info("Generate code done.")
}
generateModelFile
pools
というinternal
なパッケージを使って並列処理を行なっていることがわかりますね。pools
の詳しい内容は追っていきませんが、内部でsync.WaitGroup
を利用しているので、使いやすくラップしたパッケージとして作成しているようです。
本題の構造体生成処理を追っていきます。func(data *generator.QueryStructMeta)
が肝のようですね。render
によって、テンプレートにメタデータを埋め込んでいるようです。その後、付随するメソッドも生成し、g.output
でファイル生成しているようですね。
func (g *Generator) generateModelFile() error {
if len(g.models) == 0 {
return nil
}
modelOutPath, err := g.getModelOutputPath()
if err != nil {
return err
}
if err = os.MkdirAll(modelOutPath, os.ModePerm); err != nil {
return fmt.Errorf("create model pkg path(%s) fail: %s", modelOutPath, err)
}
errChan := make(chan error)
pool := pools.NewPool(concurrent)
for _, data := range g.models {
if data == nil || !data.Generated {
continue
}
pool.Wait()
go func(data *generate.QueryStructMeta) {
defer pool.Done()
var buf bytes.Buffer
err := render(tmpl.Model, &buf, data)
if err != nil {
errChan <- err
return
}
for _, method := range data.ModelMethods {
err = render(tmpl.ModelMethod, &buf, method)
if err != nil {
errChan <- err
return
}
}
modelFile := modelOutPath + data.FileName + ".gen.go"
err = g.output(modelFile, buf.Bytes())
if err != nil {
errChan <- err
return
}
g.info(fmt.Sprintf("generate model file(table <%s> -> {%s.%s}): %s", data.TableName, data.StructInfo.Package, data.StructInfo.Type, modelFile))
}(data)
}
select {
case err = <-errChan:
return err
case <-pool.AsyncWaitAll():
g.fillModelPkgPath(modelOutPath)
}
return nil
}
render
text/template
パッケージを使ってテンプレートにメタ情報を埋め込んでいます。
func render(tmpl string, wr io.Writer, data interface{}) error {
t, err := template.New(tmpl).Parse(tmpl)
if err != nil {
return err
}
return t.Execute(wr, data)
}
メタ情報はこのようになっています。.Fileds
がカラム一覧なので、カラムそれぞれがフィールドして定義されることがわかりますね。他にも、テーブル名がconstとして定義されることもわかりますね。念の為output
も見ていきます。
const Model = NotEditMark + `
package {{.StructInfo.Package}}
import (
"encoding/json"
"time"
"gorm.io/datatypes"
"gorm.io/gorm"
"gorm.io/gorm/schema"
{{range .ImportPkgPaths}}{{.}} ` + "\n" + `{{end}}
)
{{if .TableName -}}const TableName{{.ModelStructName}} = "{{.TableName}}"{{- end}}
// {{.ModelStructName}} {{.StructComment}}
type {{.ModelStructName}} struct {
{{range .Fields}}
{{if .MultilineComment -}}
/*
{{.ColumnComment}}
*/
{{end -}}
{{.Name}} {{.Type}} ` + "`{{.Tags}}` " +
"{{if not .MultilineComment}}{{if .ColumnComment}}// {{.ColumnComment}}{{end}}{{end}}" +
`{{end}}
}
`
output
imports.Process
でgoimports
のコードフォーマットを実行していることがわかります。最後にioutil.WriteFile
でファイルとして書き出しているようですね。
func (g *Generator) output(fileName string, content []byte) error {
result, err := imports.Process(fileName, content, nil)
if err != nil {
lines := strings.Split(string(content), "\n")
errLine, _ := strconv.Atoi(strings.Split(err.Error(), ":")[1])
startLine, endLine := errLine-5, errLine+5
fmt.Println("Format fail:", errLine, err)
if startLine < 0 {
startLine = 0
}
if endLine > len(lines)-1 {
endLine = len(lines) - 1
}
for i := startLine; i <= endLine; i++ {
fmt.Println(i, lines[i])
}
return fmt.Errorf("cannot format file: %w", err)
}
return ioutil.WriteFile(fileName, result, 0640)
}
生成処理まとめ
information_schema.talbes
とinformation_schema.columns
からテーブルとカラムの情報を取得し、テンプレートに取得した情報を埋め込むことで構造体ファイルを生成していることがわかりました。
おわりに
最後までお読みいただきありがとうございます。ふと気になってgo-gorm/gen
の詳細を追いかけて見ましたが、どのようなSQLを発行して必要な情報を取得し、構造体ファイルを生成しているか把握できて良い経験になりました。それだけでなく、OSSのコードを読むのは色々なテクニックが学べてとても刺激的ですね。理解が深まりました。