Go と reflect と generate

  • 68
    Like
  • 0
    Comment
More than 1 year has passed since last update.

この記事は Go Advent Calendar 2015 5 日目の記事です。

はじめに

Go にはジェネリクスがありませんが、任意の型に対して共通の処理を提供したいことがあります。
例えば ORM ライブラリで User テーブルと Blog テーブルを struct で表す場合を考えてみます。

type User struct {
    Id    int64
    Name  string
    Email string
}

type Blog struct {
    Id    int64
    Title string
    Body  string
}

この 2 つの struct に対して共通の処理、例えば id で検索して結果を返す FindById というメソッドを提供したいとします。

Go は静的型付き言語なので、複数の型を引数に取ったり、複数の型を返すといったメソッドを定義することはできません。
よって、それぞれの型に対してメソッドを定義することになります。

package main

var db = func() *sql.DB {
    db, err := sql.Open("mysql", "hoge")
    if err != nil {
        panic(err)
    }
    return db
}()

func (u User) FindById(id int64) (*User, error) {
    return &u, db.QueryRow("SELECT Id, Name, Email FROM User WHERE Id = ?", id).Scan(&u.Id, &u.Name, &u.Email)
}

func (b Blog) FindById(id int64) (*Blog, error) {
    return &b, db.QueryRow("SELECT Id, Title, Body FROM Blog WHERE Id = ?", id).Scan(&b.Id, &b.Title, &b.Body)
}

func main() {
    u, err := User{}.FindById(1)
    if err != nil {
        panic(err)
    }
    fmt.Printf("%#v\n", u)
    b, err := Blog{}.FindById(1)
    if err != nil {
        panic(err)
    }
    fmt.Printf("%#v\n", b)
}

これでは ORM ライブラリは事前に全ての struct を知った上で全ての struct に対して FindById メソッドを定義しないといけません。
今回はこの問題に対する 3 つの方法を紹介します。

reflect を使う方法

reflect とは 実行時に型情報や名前を取得したり、値を書き換えたりするためのパッケージです。実行時リフレクション (run-time reflection) と言ったほうがわかりやすいかもしれません。
これと interface{} を使うことによって先ほどの問題を解決できます。

拙作の ORM ライブラリ genmai もこの方法を使っています。

package main

type ORM struct{}

func (o *ORM) FindById(out interface{}, id int64) error {
    rv := reflect.ValueOf(out).Elem()
    rt := rv.Type()
    var names []string
    var args []interface{}
    for i := 0; i < rv.NumField(); i++ {
        names = append(names, rt.Field(i).Name)
        args = append(args, rv.Field(i).Addr().Interface())
    }
    columns := strings.Join(names, ",")
    return db.QueryRow(fmt.Sprintf("SELECT %s FROM %s WHERE Id = ?", columns, rt.Name()), id).Scan(args...)
}

func main() {
    orm := &ORM{}
    u := User{}
    if err := orm.FindById(&u, 1); err != nil {
        panic(err)
    }
    fmt.Printf("%#v\n", u)
    b := Blog{}
    if err := orm.FindById(&b, 1); err != nil {
        panic(err)
    }
    fmt.Printf("%#v\n", b)
}

Pros:

  • お手軽(簡単とは言ってない)
  • ライブラリだけで完結する

Cons:

  • 実行時に panic する可能性がある
  • interface{} を使うので静的型付きが有名無実化する

1 つ目は頑張ればなんとかできるとしても、2 つ目は静的型付き言語を使う以上できれば避けたいものです。

コード生成を使う方法

go:generate もとい、コード生成です。
はじめに書いた、全ての struct に対してメソッドを定義するというのをプログラムで行います。
対象となる struct の情報は、ソースファイルを AST に変換して解析するという方法が使われることが多いようです。
この際に出力するコードのテンプレートとしてよく使われるのが text/template で、genargen でも使われています。
なお下記のサンプルコードは簡略化のため struct の情報を直接書いています。

main.go
package main

import (
    "fmt"
    "go/build"
    "os"
    "strings"
    "text/template"
)

var tmpl = template.Must(template.New("").Parse(`package {{.Package}}

func (o {{.Name}}) FindById(id int64) (*{{.Name}}, error) {
    return &o, db.QueryRow("SELECT {{.Columns}} FROM {{.Name}} WHERE Id = ?", id).Scan({{.Fields}})
}
`))

func Generate(name string, columns ...string) error {
    pkg, err := build.Default.ImportDir(".", 0)
    if err != nil {
        panic(err)
    }
    fields := make([]string, 0, len(columns))
    for _, c := range columns {
        fields = append(fields, fmt.Sprintf("&o.%s", c))
    }
    f, err := os.Create(fmt.Sprintf("%s_gen.go", strings.ToLower(name)))
    if err != nil {
        return err
    }
    defer f.Close()
    return tmpl.Execute(f, map[string]interface{}{
        "Package": pkg.Name,
        "Name":    name,
        "Columns": strings.Join(columns, ", "),
        "Fields":  strings.Join(fields, ", "),
    })
}

func main() {
    if err := Generate("User", "Id", "Name", "Email"); err != nil {
        panic(err)
    }
    if err := Generate("Blog", "Id", "Title", "Body"); err != nil {
        panic(err)
    }
}

生成されるファイルは下記です。

user_gen.go
package main

func (o *User) FindById(id int64) error {
    return db.QueryRow("SELECT Id, Name, Email FROM User WHERE Id = ?", id).Scan(&o.Id, &o.Name, &o.Email)
}
blog_gen.go
package main

func (o *Blog) FindById(id int64) error {
    return db.QueryRow("SELECT Id, Title, Body FROM Blog WHERE Id = ?", id).Scan(&o.Id, &o.Title, &o.Body)
}

Pros:

  • 実行時に panic しない
  • 静的型付きのメリットを享受できる

Cons:

  • 変更があるたびにコードを生成しなおさなければいけない
  • 生成されるコードを事前にテストできない
  • テンプレートは文字列なのでコード補完やシンタックスチェッカーが使えない

2 つめは、コードを生成した後のコードにテストを書けば良さそうですが、

テンプレート変更 -> コード生成 -> テスト

という手順を踏む必要があり煩雑です。

AST からコード生成する方法

文字列テンプレートからのコード生成の問題を解決するために、実際に動くソースコードそれ自体をテンプレートにします。
実際に動くソースコード自体をテンプレートにするにはどうすればいいかというと、ソースコードを文字列として読み込んでプレースホルダーとなる識別子を strings.Replace、あるいは regexp でガリガリ書き換えていくという方法もあるにはありますが、文字列の中は書き換えたくないとか、コメントの中は書き換えたくないとか、あると思うので、AST を使います。

幸いにも Go には AST ライブラリ go/ast が標準で備わっていますので、これを使って頑張ることができます。

template.go
package main

func (o PlaceHolder) FindById(id int64) (*PlaceHolder, error) {
    return &o, db.QueryRow("SELECT "+o.columns()+" FROM User WHERE Id = ?", id).Scan(o.fields(&o)...)
}

func (o PlaceHolder) columns() string {
    return "Id, Name, Email"
}

func (o PlaceHolder) fields(p *PlaceHolder) []interface{} {
    return []interface{}{&o.Id, &o.Name, &o.Email}
}
main.go
package main

import (
    "fmt"
    "go/ast"
    "go/build"
    "go/format"
    "go/parser"
    "go/token"
    "os"
    "strconv"
    "strings"
)

func GenerateFromAST(name string, columns ...string) error {
    pkg, err := build.Default.ImportDir(".", 0)
    if err != nil {
        panic(err)
    }
    fset := token.NewFileSet()
    f, err := parser.ParseFile(fset, "template.go", nil, parser.ParseComments)
    f.Name.Name = pkg.Name
    rewriteASTNode(f, name, columns)
    file, err := os.Create(fmt.Sprintf("%s_gen.go", strings.ToLower(name)))
    if err != nil {
        return err
    }
    defer file.Close()
    return format.Node(file, fset, f)
}

func rewriteASTNode(f ast.Node, structName string, columns []string) {
    ast.Inspect(f, func(n ast.Node) bool {
        switch aType := n.(type) {
        case *ast.Ident:
            if strings.Contains(aType.Name, "PlaceHolder") {
                aType.Name = strings.Replace(aType.Name, "PlaceHolder", structName, 1)
            }
        case *ast.FuncDecl:
            if aType.Recv == nil {
                break
            }
            switch aType.Name.Name {
            case "columns":
                aType.Body.List[0].(*ast.ReturnStmt).Results[0].(*ast.BasicLit).Value = strconv.Quote(strings.Join(columns, ", "))
            case "fields":
                recvName := aType.Recv.List[0].Names[0].Name
                clit := aType.Body.List[0].(*ast.ReturnStmt).Results[0].(*ast.CompositeLit)
                clit.Elts = clit.Elts[:0]
                for _, c := range columns {
                    clit.Elts = append(clit.Elts, &ast.UnaryExpr{
                        Op: token.AND,
                        X: &ast.SelectorExpr{
                            X:   ast.NewIdent(recvName),
                            Sel: ast.NewIdent(c),
                        },
                    })
                }
            }
        }
        return true
    })
}

func main() {
    if err := GenerateFromAST("User", "Id", "Name", "Email"); err != nil {
        panic(err)
    }
    if err := GenerateFromAST("Blog", "Id", "Title", "Body"); err != nil {
        panic(err)
    }
}

生成されるファイルは下記です。

user_gen.go
package main

func (o User) FindById(id int64) (*User, error) {
    return &o, db.QueryRow("SELECT "+o.columns()+" FROM User WHERE Id = ?", id).Scan(o.fields(&o)...)
}

func (o User) columns() string {
    return "Id, Name, Email"
}

func (o User) fields(p *User) []interface{} {
    return []interface{}{&o.Id, &o.Name, &o.Email}
}
blog_gen.go
package main

func (o Blog) FindById(id int64) (*Blog, error) {
    return &o, db.QueryRow("SELECT "+o.columns()+" FROM User WHERE Id = ?", id).Scan(o.fields(&o)...)
}

func (o Blog) columns() string {
    return "Id, Title, Body"
}

func (o Blog) fields(p *Blog) []interface{} {
    return []interface{}{&o.Id, &o.Title, &o.Body}
}

Pros:

  • 実行時に panic しない
  • 静的型付きのメリットを享受できる
  • 生成されるコードを事前にテストすることができる
  • テンプレートを書くときにコード補完やシンタックスチェッカーを使える

Cons:

  • ご覧の通り全て AST で組み立てる必要があるため複雑
  • 書き換えられる部分の構造を少しでも変えると動かなくなる

まとめ

任意の型に対して共通の処理を提供する方法として reflect を使った方法、 text/template を使ったコード生成、AST からコード生成する方法の 3 つを紹介しました。
どれも一長一短ありますが、個人的には事前にテストができるのと、コード補完が効くという理由で AST からコード生成する方法が好きです。

これらの他にもより良い方法があればぜひお知らせください。