LoginSignup
8
2

More than 1 year has passed since last update.

go-gorm/genの挙動を追いかけてみた

Posted at

はじめに

みなさん、こんにちは
Amebaの広告プロダクトチームでバックエンドエンジニアをやっている永井です。
この記事は、メディア事業部の広告横軸組織PTAアドベントカレンダー1日目の記事となります。

動機

仕事でDBにアクセスするアプリケーションをScalaからGoに書き換えるにあたって、テーブル定義を構造体からコツコツ作るのめんどくさいなと思い、自動でいい感じに生成してくれるツールないかなと検討した結果、genが便利そうだったので利用することにしました。
利用する分には、便利なのですが実際どのように自動生成してるのか気になったので調べてみました。

genとは

RDBのテーブル定義からGoの構造体を自動生成してくれるツールです。genで生成した構造体はgormで利用可能です。

挙動追跡

下記のコードを実行するだけでテーブル定義に沿ったGoの構造体をサクッと生成してくれます。

package main

import "gorm.io/gen"

func main() {
  g := gen.NewGenerator(gen.Config{
    OutPath: "../query",
    Mode: gen.WithDefaultQuery, // generate mode
  })

  // gormでDBに接続します。接続先の情報は適当なものです。
  db, _ := gorm.Open(mysql.Open("root:@(127.0.0.1:3306)/demo?charset=utf8mb4&parseTime=True&loc=Local"))
  g.UseDB(db) // reuse your gorm db
  g.GenerateAllTable()
  g.Execute()
}

流れとしては、下記のようになっています。今回は挙動を追跡するにあたって、UseDB以降のソースコードを読み解いていきます。

  1. Generator生成 NewGenerator
  2. 利用するDB指定 UseDB
  3. 全テーブルを自動生成対象として指定 GenerateAllTable
  4. 実行 Execute

UseDB

func (g *Generator) UseDB(db *gorm.DB) {
	if db != nil {
		g.db = db
	}
}

引数として受け取ったdbをdbフィールドに代入していますね。dbフィールドを経由してRDBにアクセスするんだろうなということが伺えます。
ちなみにGeneratorはこのような定義になっています。Configのフィールドとして定義されているようですね。

type Generator struct {
	Config

	Data   map[string]*genInfo                  //gen query data
	models map[string]*generate.QueryStructMeta //gen model data
}

Configの定義はこのようになっており、dbがありました。色々と設定項目があることが伺えます。今回は細かく触れません。

type Config struct {
	db *gorm.DB // db connection

	OutPath      string // query code path
	OutFile      string // query code file name, default: gen.go
	ModelPkgPath string // generated model code's package name
	WithUnitTest bool   // generate unit test for query code

	// generate model global configuration
	FieldNullable     bool // generate pointer when field is nullable
	FieldCoverable    bool // generate pointer when field has default value, to fix problem zero value cannot be assign: https://gorm.io/docs/create.html#Default-Values
	FieldSignable     bool // detect integer field's unsigned type, adjust generated data type
	FieldWithIndexTag bool // generate with gorm index tag
	FieldWithTypeTag  bool // generate with gorm column type tag

	Mode GenerateMode // generate mode

	queryPkgName   string // generated query code's package name
	modelPkgPath   string // model pkg path in target project
	dbNameOpts     []model.SchemaNameOpt
	importPkgPaths []string

	// name strategy for syncing table from db
	tableNameNS func(tableName string) (targetTableName string)
	modelNameNS func(tableName string) (modelName string)
	fileNameNS  func(tableName string) (fileName string)

	dataTypeMap    map[string]func(detailType string) (dataType string)
	fieldJSONTagNS func(columnName string) (tagContent string)
	fieldNewTagNS  func(columnName string) (tagContent string)

	modelOpts []ModelOpt
}

GenerateAllTable

func (g *Generator) GenerateAllTable(opts ...ModelOpt) (tableModels []interface{}) {
	tableList, err := g.db.Migrator().GetTables()
	if err != nil {
		panic(fmt.Errorf("get all tables fail: %w", err))
	}

	g.info(fmt.Sprintf("find %d table from db: %s", len(tableList), tableList))

	tableModels = make([]interface{}, len(tableList))
	for i, tableName := range tableList {
		tableModels[i] = g.GenerateModel(tableName, opts...)
	}
	return tableModels
}

この関数は下記のような流れになっています。それぞれ詳しく見ていきます。

  1. テーブル一覧取得 g.db.Migrator().GetTables()
  2. 各テーブルに対して GenerateModelを実行する

GetTables

各種DriverごとにMigratorが存在するようです。今回はmysqlのGetTablesを見ていきます。直接SQLが記述されていますね。information_schema.tablesというテーブルからテーブル一覧を取得していることがわかります。

func (m Migrator) GetTables() (tableList []string, err error) {
	err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()).
		Scan(&tableList).Error
	return
}

GenerateModel

内部でGenerateModelAsを呼び出していますね。

func (g *Generator) GenerateModel(tableName string, opts ...ModelOpt) *generate.QueryStructMeta {
	return g.GenerateModelAs(tableName, g.db.Config.NamingStrategy.SchemaName(tableName), opts...)
}

GenerateModelAs

中身を見ていくと、テーブル名からメタ情報を取得し、メタ情報をgeneratormodelsマップに代入していますね。肝はGetQueryStructMeta関数のようです。

func (g *Generator) GenerateModelAs(tableName string, modelName string, opts ...ModelOpt) *generate.QueryStructMeta {
	meta, err := generate.GetQueryStructMeta(g.db, g.genModelConfig(tableName, modelName, opts))
	if err != nil {
		g.db.Logger.Error(context.Background(), "generate struct from table fail: %s", err)
		panic("generate struct fail")
	}
	if meta == nil {
		g.info(fmt.Sprintf("ignore table <%s>", tableName))
		return nil
	}
	g.models[meta.ModelStructName] = meta

	g.info(fmt.Sprintf("got %d columns from table <%s>", len(meta.Fields), meta.TableName))
	return meta
}

GetQueryStructMeta

getTableColumnsでカラムの一覧を取得し、QueryStructMetaを返していますね。メタ情報がたくさんありますね。getTableColumnsが大事そうです。

func GetQueryStructMeta(db *gorm.DB, conf *model.Config) (*QueryStructMeta, error) {
	if _, ok := db.Config.Dialector.(tests.DummyDialector); ok {
		return nil, fmt.Errorf("UseDB() is necessary to generate model struct [%s] from database table [%s]", conf.ModelName, conf.TableName)
	}

	conf = conf.Preprocess()
	tableName, structName, fileName := conf.GetNames()
	if tableName == "" {
		return nil, nil
	}
	if err := checkStructName(structName); err != nil {
		return nil, fmt.Errorf("model name %q is invalid: %w", structName, err)
	}

	columns, err := getTableColumns(db, conf.GetSchemaName(db), tableName, conf.FieldWithIndexTag)
	if err != nil {
		return nil, err
	}

	return (&QueryStructMeta{
		db:              db,
		Source:          model.Table,
		Generated:       true,
		FileName:        fileName,
		TableName:       tableName,
		ModelStructName: structName,
		QueryStructName: uncaptialize(structName),
		S:               strings.ToLower(structName[0:1]),
		StructInfo:      parser.Param{Type: structName, Package: conf.ModelPkg},
		ImportPkgPaths:  conf.ImportPkgPaths,
		Fields:          getFields(db, conf, columns),
	}).addMethodFromAddMethodOpt(conf.GetModelMethods()...), nil
}

getTableColumns

GetTableColumnsでカラムの一覧を取得し、GetTableIndexでインデックスの一覧も取得していますね。その後、カラムの情報にindexの情報を追加しているようです。GetTableColumnsの中身が気になります。

func getTableColumns(db *gorm.DB, schemaName string, tableName string, indexTag bool) (result []*model.Column, err error) {
	if db == nil {
		return nil, errors.New("gorm db is nil")
	}

	mt := getTableInfo(db)
	result, err = mt.GetTableColumns(schemaName, tableName)
	if err != nil {
		return nil, err
	}
	if !indexTag || len(result) == 0 {
		return result, nil
	}

	index, err := mt.GetTableIndex(schemaName, tableName)
	if err != nil { //ignore find index err
		db.Logger.Warn(context.Background(), "GetTableIndex for %s,err=%s", tableName, err.Error())
		return result, nil
	}
	if len(index) == 0 {
		return result, nil
	}

	im := model.GroupByColumn(index)
	for _, c := range result {
		c.Indexes = im[c.Name()]
	}
	return result, nil
}

GetTableColumns

Migratorが出てきましたね。もうすぐSQLが拝めそうです。ColumnTypesでカラムの一覧を取得しているようですね。

func (t *tableInfo) GetTableColumns(schemaName string, tableName string) (result []*model.Column, err error) {
	types, err := t.Migrator().ColumnTypes(tableName)
	if err != nil {
		return nil, err
	}
	for _, column := range types {
		result = append(result, &model.Column{ColumnType: column, TableName: tableName, UseScanType: t.Dialector.Name() != "mysql" && t.Dialector.Name() != "sqlite"})
	}
	return result, nil
}

ColumnTypes

読みたくなくなるコード量になってきましたが、ついにSQLに辿り着きました。大事なのはどのテーブルからどのような情報を取得しているかです。information_schema.columnsからtable_schematable_nameを指定しているようですね。そこから、カラムの名前column_nameやデフォルト値column_default、データタイプdata_typeなどなどを取得しているようです。その後、gormの構造体して組み立てて返却していますね。

func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
	columnTypes := make([]gorm.ColumnType, 0)
	err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
		var (
			currentDatabase, table = m.CurrentSchema(stmt, stmt.Table)
			columnTypeSQL          = "SELECT column_name, column_default, is_nullable = 'YES', data_type, character_maximum_length, column_type, column_key, extra, column_comment, numeric_precision, numeric_scale "
			rows, err              = m.DB.Session(&gorm.Session{}).Table(table).Limit(1).Rows()
		)

		if err != nil {
			return err
		}

		rawColumnTypes, err := rows.ColumnTypes()

		if err := rows.Close(); err != nil {
			return err
		}

		if !m.DisableDatetimePrecision {
			columnTypeSQL += ", datetime_precision "
		}
		columnTypeSQL += "FROM information_schema.columns WHERE table_schema = ? AND table_name = ? ORDER BY ORDINAL_POSITION"

		columns, rowErr := m.DB.Raw(columnTypeSQL, currentDatabase, table).Rows()
		if rowErr != nil {
			return rowErr
		}

		defer columns.Close()

		for columns.Next() {
			var (
				column            migrator.ColumnType
				datetimePrecision sql.NullInt64
				extraValue        sql.NullString
				columnKey         sql.NullString
				values            = []interface{}{
					&column.NameValue, &column.DefaultValueValue, &column.NullableValue, &column.DataTypeValue, &column.LengthValue, &column.ColumnTypeValue, &columnKey, &extraValue, &column.CommentValue, &column.DecimalSizeValue, &column.ScaleValue,
				}
			)

			if !m.DisableDatetimePrecision {
				values = append(values, &datetimePrecision)
			}

			if scanErr := columns.Scan(values...); scanErr != nil {
				return scanErr
			}

			column.PrimaryKeyValue = sql.NullBool{Bool: false, Valid: true}
			column.UniqueValue = sql.NullBool{Bool: false, Valid: true}
			switch columnKey.String {
			case "PRI":
				column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
			case "UNI":
				column.UniqueValue = sql.NullBool{Bool: true, Valid: true}
			}

			if strings.Contains(extraValue.String, "auto_increment") {
				column.AutoIncrementValue = sql.NullBool{Bool: true, Valid: true}
			}

			column.DefaultValueValue.String = strings.Trim(column.DefaultValueValue.String, "'")
			if m.Dialector.DontSupportNullAsDefaultValue {
				// rewrite mariadb default value like other version
				if column.DefaultValueValue.Valid && column.DefaultValueValue.String == "NULL" {
					column.DefaultValueValue.Valid = false
					column.DefaultValueValue.String = ""
				}
			}

			if datetimePrecision.Valid {
				column.DecimalSizeValue = datetimePrecision
			}

			for _, c := range rawColumnTypes {
				if c.Name() == column.NameValue.String {
					column.SQLColumnType = c
					break
				}
			}

			columnTypes = append(columnTypes, column)
		}

		return nil
	})

	return columnTypes, err
}

GenerateModelまとめ

長くなってきたので一旦まとめましょう。いまはGenerateModelを深掘り、カラムの情報をinformation_schema.columnsから取得し、メタ情報として組み立ててGeneratormodelsマップに代入していることがわかりました。

GenerateAllTableまとめ

さらに上の階層のGenerateAllTableについてもここでまとめましょう。GetTablesでテーブル一覧を取得し、各テーブルについてGenerateModelでカラム情報などのメタ情報を取得し、Generatormodelsマップに代入していることがわかりました。

Execute

さて、最後にExecuteです。generateModelFileを実行して構造体を生成し、generateQueryFileを実行してクエリ用のコードを生成しているようです。今回は構造体の生成に着目して、generateModelFileを見ていきます。

func (g *Generator) Execute() {
	g.info("Start generating code.")

	if err := g.generateModelFile(); err != nil {
		g.db.Logger.Error(context.Background(), "generate model struct fail: %s", err)
		panic("generate model struct fail")
	}

	if err := g.generateQueryFile(); err != nil {
		g.db.Logger.Error(context.Background(), "generate query code fail: %s", err)
		panic("generate query code fail")
	}

	g.info("Generate code done.")
}

generateModelFile

poolsというinternalなパッケージを使って並列処理を行なっていることがわかりますね。poolsの詳しい内容は追っていきませんが、内部でsync.WaitGroupを利用しているので、使いやすくラップしたパッケージとして作成しているようです。

本題の構造体生成処理を追っていきます。func(data *generator.QueryStructMeta)が肝のようですね。renderによって、テンプレートにメタデータを埋め込んでいるようです。その後、付随するメソッドも生成し、g.outputでファイル生成しているようですね。

func (g *Generator) generateModelFile() error {
	if len(g.models) == 0 {
		return nil
	}

	modelOutPath, err := g.getModelOutputPath()
	if err != nil {
		return err
	}

	if err = os.MkdirAll(modelOutPath, os.ModePerm); err != nil {
		return fmt.Errorf("create model pkg path(%s) fail: %s", modelOutPath, err)
	}

	errChan := make(chan error)
	pool := pools.NewPool(concurrent)
	for _, data := range g.models {
		if data == nil || !data.Generated {
			continue
		}
		pool.Wait()
		go func(data *generate.QueryStructMeta) {
			defer pool.Done()

			var buf bytes.Buffer
			err := render(tmpl.Model, &buf, data)
			if err != nil {
				errChan <- err
				return
			}

			for _, method := range data.ModelMethods {
				err = render(tmpl.ModelMethod, &buf, method)
				if err != nil {
					errChan <- err
					return
				}
			}

			modelFile := modelOutPath + data.FileName + ".gen.go"
			err = g.output(modelFile, buf.Bytes())
			if err != nil {
				errChan <- err
				return
			}

			g.info(fmt.Sprintf("generate model file(table <%s> -> {%s.%s}): %s", data.TableName, data.StructInfo.Package, data.StructInfo.Type, modelFile))
		}(data)
	}
	select {
	case err = <-errChan:
		return err
	case <-pool.AsyncWaitAll():
		g.fillModelPkgPath(modelOutPath)
	}
	return nil
}

render

text/templateパッケージを使ってテンプレートにメタ情報を埋め込んでいます。

func render(tmpl string, wr io.Writer, data interface{}) error {
	t, err := template.New(tmpl).Parse(tmpl)
	if err != nil {
		return err
	}
	return t.Execute(wr, data)
}

メタ情報はこのようになっています。.Filedsがカラム一覧なので、カラムそれぞれがフィールドして定義されることがわかりますね。他にも、テーブル名がconstとして定義されることもわかりますね。念の為outputも見ていきます。

const Model = NotEditMark + `
package {{.StructInfo.Package}}

import (
	"encoding/json"
	"time"

	"gorm.io/datatypes"
	"gorm.io/gorm"
	"gorm.io/gorm/schema"
	{{range .ImportPkgPaths}}{{.}} ` + "\n" + `{{end}}
)

{{if .TableName -}}const TableName{{.ModelStructName}} = "{{.TableName}}"{{- end}}

// {{.ModelStructName}} {{.StructComment}}
type {{.ModelStructName}} struct {
    {{range .Fields}}
	{{if .MultilineComment -}}
	/*
{{.ColumnComment}}
    */
	{{end -}}
    {{.Name}} {{.Type}} ` + "`{{.Tags}}` " +
	"{{if not .MultilineComment}}{{if .ColumnComment}}// {{.ColumnComment}}{{end}}{{end}}" +
	`{{end}}
}

`

output

imports.Processgoimportsのコードフォーマットを実行していることがわかります。最後にioutil.WriteFileでファイルとして書き出しているようですね。

func (g *Generator) output(fileName string, content []byte) error {
	result, err := imports.Process(fileName, content, nil)
	if err != nil {
		lines := strings.Split(string(content), "\n")
		errLine, _ := strconv.Atoi(strings.Split(err.Error(), ":")[1])
		startLine, endLine := errLine-5, errLine+5
		fmt.Println("Format fail:", errLine, err)
		if startLine < 0 {
			startLine = 0
		}
		if endLine > len(lines)-1 {
			endLine = len(lines) - 1
		}
		for i := startLine; i <= endLine; i++ {
			fmt.Println(i, lines[i])
		}
		return fmt.Errorf("cannot format file: %w", err)
	}
	return ioutil.WriteFile(fileName, result, 0640)
}

生成処理まとめ

information_schema.talbesinformation_schema.columnsからテーブルとカラムの情報を取得し、テンプレートに取得した情報を埋め込むことで構造体ファイルを生成していることがわかりました。

おわりに

最後までお読みいただきありがとうございます。ふと気になってgo-gorm/genの詳細を追いかけて見ましたが、どのようなSQLを発行して必要な情報を取得し、構造体ファイルを生成しているか把握できて良い経験になりました。それだけでなく、OSSのコードを読むのは色々なテクニックが学べてとても刺激的ですね。理解が深まりました。

8
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
2