CastingONE Advent Calendar 2023 17日目の記事です。
はじめに
実装を進めていると、ふと「なんか毎回同じようなこと書いているな」って思うことありません??
僕は結構あります!「repositoryのinterfaceで実装書く時、毎回traceの処理書いてるな〜」とか。
結構書き忘れもあるので、interfaceを元にこの辺りの処理全部生成できたら嬉しいと毎回思ってます笑
ということで、今回はASTを使ってうまくコード生成する方法を紹介します!
ASTとはなんぞやというところから実装まで紹介するので、ぜひ最後まで読んでいってください〜
ちなみに、本書の実装はこちらのリポジトリで公開しているので、詳しくみたいという方は参考にしてください!
ASTとは??
AST(Abstract Syntax Tree)はソースコードの構造を、ツリー構造で表現したものになります。このツリー構造はコードの構文的な構造を表します。プログラムの各要素、例えば変数や関数、演算子等が階層構造で表されるイメージです。
例えば以下のようなファイルを考えてみましょう!
package main
type User struct {
Name string
}
こちらのファイルのコードについて、astの構造を出力してみると以下のようになります(必要な箇所のみ抜粋)。
ファイル内の構造がastのオブジェクトに反映されており、ツリー構造になっていることがなんとなくわかります。
0 *ast.File {
1 . Package: 1
2 . Name: *ast.Ident {
4 . . Name: "main"
5 . }
6 . Decls: []ast.Decl (len = 1) {
7 . . 0: *ast.GenDecl {
11 . . . Specs: []ast.Spec (len = 1) {
12 . . . . 0: *ast.TypeSpec {
13 . . . . . Name: *ast.Ident {
15 . . . . . . Name: "Person"
16 . . . . . . Obj: *ast.Object {
17 . . . . . . . Kind: type
18 . . . . . . . Name: "Person"
20 . . . . . . }
21 . . . . . }
23 . . . . . Type: *ast.StructType {
24 . . . . . . Struct: 27
25 . . . . . . Fields: *ast.FieldList {
27 . . . . . . . List: []*ast.Field (len = 1) {
28 . . . . . . . . 0: *ast.Field {
29 . . . . . . . . . Names: []*ast.Ident (len = 1) {
30 . . . . . . . . . . 0: *ast.Ident {
32 . . . . . . . . . . . Name: "Name"
33 . . . . . . . . . . . Obj: *ast.Object {
34 . . . . . . . . . . . . Kind: var
35 . . . . . . . . . . . . Name: "Name"
37 . . . . . . . . . . . }
38 . . . . . . . . . . }
39 . . . . . . . . . }
40 . . . . . . . . . Type: *ast.Ident {
42 . . . . . . . . . . Name: "string"
43 . . . . . . . . . }
66 }
このようなツリー構造を活用してファイルの内部構造を解析することで、自作リンターを開発したり、特定の要件に合わせたコードを生成したりすることが可能になります。
astの出力を確認する方法
上記の出力は、以下のようなコードで確認することができます。
package main
import (
"go/ast"
"go/parser"
"go/token"
)
func main() {
f, err := parser.ParseFile(token.NewFileSet(), "user.go", nil, 0)
if err != nil {
panic(err)
}
ast.Print(nil, f)
}
自動生成してみよう
ASTとは何かを見てきたところで、本題の自動生成の実装についてみていきます!
今回はrepsitoryのinterfaceの実装を出力してみます。
例えば、こんな感じのファイルから
package domain
import "context"
type User struct {
ID int
Name string
Email string
}
type UserRepository interface {
FindById(ctx context.Context, id int) (User, error)
FindAll(ctx context.Context) ([]User, error)
Store(ctx context.Context, user *User) error
}
以下のような出力ができるように実装してみます!
package infrastructure
import (
"context"
"log"
"github.com/hiroaki-u/ast-playground/examples/domain"
)
type userRepository struct {
}
func NewUserRepository() domain.UserRepository {
return &userRepository{}
}
func (r *userRepository) FindById(ctx context.Context, id int) (domain.User, error) {
log.Default().Println("UserRepository.FindById")
return domain.User{}, nil
}
func (r *userRepository) FindAll(ctx context.Context) ([]domain.User, error) {
log.Default().Println("UserRepository.FindAll")
return nil, nil
}
func (r *userRepository) Store(ctx context.Context, user *domain.User) error {
log.Default().Println("UserRepository.Store")
return nil
}
処理全体の流れ
今回実装した処理の流れをみていきます。大きく以下の通りです。
- ASTを取得して、必要な情報を抽出する
- 取得した情報から実装内容を出力する
処理は以下のようになっています
※今回はurfave/cliというcliライブラリを用いており、input_file等はcliのオプションとして受け取るようにしています
func createRepository(cCtx *cli.Context) {
// inputファイルを取得する
in := cCtx.String("input_file")
if in == "" {
log.Fatal("input_file is required")
return
}
// outputファイルを取得する
out := cCtx.String("output_file")
if out == "" {
log.Fatal("output_file is required")
return
}
// inputファイルを読み込み、interfaceとそのメソッドの情報を取得する
repo, err := parseRepositoryStructure(in)
if err != nil {
log.Fatal(err)
return
}
// inputファイルのパッケージ名を取得
pkgName, err := getPackageName(in)
if err != nil {
log.Fatal(err)
return
}
// 取得したrepositoryの構造体から、repositoryの実装ファイルに記載する内容を作成する
builder := NewRepositoryContentBuilder()
ss, err := builder.Execute(repo, pkgName, getDirectoryName(out))
if err != nil {
log.Fatal(err)
return
}
// outputファイルに書き込む
p, err := os.Create(out)
if err != nil {
log.Fatal(err)
return
}
defer p.Close()
if _, err := p.Write([]byte(ss)); err != nil {
log.Fatal(err)
return
}
}
再掲になりますが、おおまかな処理の流れは以下の2つですので、これから詳細を見ていこうと思います
- ASTを取得して、必要な情報を抽出する
- 取得した情報から実装内容を出力する
ASTを取得して、必要な情報を抽出する
ここではASTを取得して、欲しい情報(repositoryのinterface)を抽出する過程をご紹介します。
まずは今回ASTから取得した値を格納しておくstructを見ていきましょう。
type Repository struct {
Name string
Methods []*Method
}
type Method struct {
Name string
Args MethodValues
Returns MethodValues
}
type MethodValues []MethodValue
type MethodValue struct {
Type *MethodType
Values []string
}
type MethodType struct {
isSlice bool // slice
isPointer bool // ポインタ
isVariadic bool // 可変長引数
isPrimitive bool // 基本型
requirePkgName bool // package名が必要な場合
Value string // 型名
}
MethodType
は結構シンプルに書いています。
あらゆる型に対応する場合はもっと色々考える必要はありますが、今回はシンプルな型のみ対象としました。
続いてmain.goから呼び出されていたparseRepositoryStructure
関数を見ていきます。
package main
import (
"go/ast"
"go/parser"
"go/token"
)
// ファイルを走査し、ファイル内で見つかったrepositoryを返す
func parseRepositoryStructure(file string) (Repository, error) {
// ファイルをパースして、astを取得する
repo := Repository{}
f, err := parser.ParseFile(token.NewFileSet(), file, nil, parser.Mode(0))
if err != nil {
return repo, err
}
// astを走査して、必要な情報を抽出
ast.Inspect(f, func(n ast.Node) bool {
methods := []*Method{}
switch x := n.(type) {
case *ast.TypeSpec:
// 対象がinterfaceであるため、interfaceの情報を取得する
it, ok := x.Type.(*ast.InterfaceType)
if !ok {
return true
}
repo.Name = x.Name.Name
// 関数の情報を取得
for _, field := range it.Methods.List {
funcType, ok := field.Type.(*ast.FuncType)
if !ok {
continue
}
method := Method{}
method.Name = field.Names[0].Name
// 引数と返り値を取得する
method.Args = ExtractMethodValues(funcType.Params.List)
method.Returns = ExtractMethodValues(funcType.Results.List)
methods = append(methods, &method)
}
repo.Methods = methods
}
return true
})
return repo, nil
}
最初にgo/parser
のParseFile
関数を用いて該当ファイルのastを取得しています。ParseFile
関数の返り値の型は*ast.File
です。
ちなみにast.File
は以下のような構造体となっています。確かにファイル内の情報が詰め込まれていそうです。
type File struct {
Doc *CommentGroup // 関連するドキュメント;ない場合はnil
Package token.Pos // "package" キーワードの位置
Name *Ident // パッケージ名
Decls []Decl // トップレベルの宣言;ない場合はnil
FileStart, FileEnd token.Pos // ファイル全体の開始と終了位置
Scope *Scope // パッケージスコープ(このファイルのみ)
Imports []*ImportSpec // このファイルのインポート
Unresolved []*Ident // このファイル内の未解決識別子
Comments []*CommentGroup // ソースファイル内のすべてのコメントのリスト
GoVersion string //go:build や +build ディレクティブによる要求される最小Goバージョン
}
この構造体は、ast.Node
というinterfaceを実装しています。
astには色々なオブジェクトが出てきますが、基本的にはast.Node
というinterfaceを実装しており、それを用いて繰り返し構造を成り立たせています。
少し脱線したので、今回の処理の話を戻しましょう!
ast.Fileを取得したら、Inspect
関数を使って走査して構造を捉えていきます。
ちなみにInspect
関数は以下のような実装となっており、Node
を実装しているstructについて走査処理を実行します。
func Inspect(node Node, f func(Node) bool) {
Walk(inspector(f), node)
}
Inspect
関数の第二引数は関数であるため、この部分で情報の抽出を行っています。
今回はこの第二引数の関数にて、必要となる情報の抽出をしてます。
astにより取得したい対象がinterfaceであるため、それに該当するInterfaceType
というNode
のみを取得するようにしました。
ast.InterfaceType
は以下のような構造をもちます。Methods
というフィールドがinterfaceがもつメソッドの情報を保持しているため、その後の処理にてMethods
からメソッドの情報を取得していきます。
// An InterfaceType node represents an interface type.
InterfaceType struct {
Interface token.Pos // position of "interface" keyword
Methods *FieldList // list of embedded interfaces, methods, or types
Incomplete bool // true if (source) methods or types are missing in the Methods list
}
メソッドの引数や返り値はExtractMethodValues
関数にて処理するようにしました。詳細は以下の通りです。
func ExtractMethodValues(list []*ast.Field) []MethodValue {
mvs := []MethodValue{}
for _, param := range list {
mv := MethodValue{}
mt := &MethodType{}
mv.Type = IdentifyNodeType(param.Type, mt)
if param.Names != nil {
for _, p := range param.Names {
mv.AppendValue(p.Name)
}
}
mvs = append(mvs, mv)
}
return mvs
}
個々のメソッドについては、IdentifyNodeType
関数にて型の分類をしています。詳細を見ていきましょう。
// 各引数や返り値の型を特定して、MethodTypeに格納する
func IdentifyNodeType(t ast.Expr, mt *MethodType) {
switch t.(type) {
// sliceの場合
case *ast.ArrayType:
se := t.(*ast.ArrayType).Elt
mt.isSlice = true
IdentifyNodeType(se, mt)
// pointer型の場合
case *ast.StarExpr:
se, _ := t.(*ast.StarExpr).X.(*ast.Ident)
if se != nil {
if !isPrimitive(se) {
mt.isPointer = true
mt.requirePkgName = true
mt.Value = se.Name
}
} else {
se, _ := t.(*ast.StarExpr).X.(*ast.SelectorExpr)
x := se.X.(*ast.Ident)
sel := se.Sel
mt.isPointer = true
mt.Value = x.Name + "." + sel.Name
}
// シンプルな型の場合(primitive型やstruct)
case *ast.Ident:
se := t.(*ast.Ident)
mt.Value = se.Name
if !isPrimitive(t.(*ast.Ident)) {
mt.requirePkgName = true
}
// package + structの場合
case *ast.SelectorExpr:
x := t.(*ast.SelectorExpr).X.(*ast.Ident)
sel := t.(*ast.SelectorExpr).Sel
mt.Value = x.Name + "." + sel.Name
// 可変引数
case *ast.Ellipsis:
se := t.(*ast.Ellipsis).Elt
mt.Value = "..." + se.(*ast.Ident).Name
default:
}
}
引数のast.Expr
は式ノード(Expression Node)を表すinterfaceです。
ちなみに式のノードは値を生成するプログラムを指します。関数の呼び出しだけでなく、値自体もそれに該当します。
IdentifyNodeType
関数では、InterfaceType.Methods
に入っていた値の型ごとに処理を分類しています。
他にも色々な型はありますが、今回はシンプルな型のみを扱うようにしました。
以上がastで構造を取得するコードでした!
取得した情報から実装内容を出力する
次にastによって取得した値を出力していきます。
astutilを用いても実装できると思いますが、今回は構造の変更しやすさを考えてtemplate使って出力する方法を選択しました。
最終的な出力は以下のような形でしたので、まずはこれに該当するテンプレートを作成します。
package infrastructure
import (
"context"
"log"
"github.com/hiroaki-u/ast-playground/examples/domain"
)
type userRepository struct {
}
func NewUserRepository() domain.UserRepository {
return &userRepository{}
}
func (r *userRepository) FindById(ctx context.Context, id int) (domain.User, error) {
log.Default().Println("UserRepository.FindById")
return domain.User{}, nil
}
func (r *userRepository) FindAll(ctx context.Context) ([]domain.User, error) {
log.Default().Println("UserRepository.FindAll")
return nil, nil
}
func (r *userRepository) Store(ctx context.Context, user *domain.User) error {
log.Default().Println("UserRepository.Store")
return nil
}
用意したテンプレートはこちらの2つ。
package {{ .Package }}
type {{ .LowerName }} struct {
}
func New{{ .Name }}() {{ .InterfacePackage }}.{{ .Name }} {
return & {{ .LowerName }}{}
}
func ({{ .ReceiverValue }} *{{ .ReceiverType }}) {{ .MethodName }}({{ .Args }}) ({{ .ReturnArgs }}) {
{{ .Body.Log }}
{{ .Body.ReturnValue }}
}
これらのファイルはembedを使ってtemplateを用意しておきます。
var (
//go:embed templates/*
templates embed.FS
// 関数用テンプレート
methodTemplate = template.Must(template.ParseFS(templates, "templates/method.tpl"))
// New関数用テンプレート
newFuncTemplate = template.Must(template.ParseFS(templates, "templates/factory.tpl"))
)
Repository生成専用のstructにこれらのtemplateをもたせて、methodとnew関数の生成処理を実行しました。
type RepositoryContentBuilder struct {
MethodTemplate *template.Template
NewFuncTemplate *template.Template
}
func NewRepositoryContentBuilder() *RepositoryContentBuilder {
return &RepositoryContentBuilder{
MethodTemplate: methodTemplate,
NewFuncTemplate: newFuncTemplate,
}
}
methodとnew関数の生成過程はほぼ一緒なので、今回はmethodを例に処理をみていきます。
// メソッドの作成
func (rc *RepositoryContentBuilder) createMethod(repo Repository, pkgName string) ([]string, error) {
res := []string{}
for _, method := range repo.Methods {
returnList := []string{}
for _, v := range method.Returns {
returnList = append(returnList, v.Type.getZeroValue(pkgName))
}
body := methodBodyParameter{}
body.ReturnValue = "return " + strings.Join(returnList, ", ")
body.Log = `log.Default().Println("` + repo.Name + "." + method.Name + `")`
content := &methodParameter{
ReceiverValue: "r",
ReceiverType: repo.getLowerName(),
MethodName: method.Name,
Args: method.Args.GetTemplate(pkgName),
ReturnArgs: method.Returns.GetTemplate(pkgName),
Body: body,
}
var buf bytes.Buffer
if err := rc.MethodTemplate.Execute(&buf, content); err != nil {
return nil, err
}
res = append(res, buf.String())
}
return res, nil
}
type methodParameter struct {
ReceiverValue string
ReceiverType string
MethodName string
Args string
ReturnArgs string
Body methodBodyParameter
}
type methodBodyParameter struct {
Log string
ReturnValue string
}
基本的には用意したtemplateの値を構築するだけなので、あまりややこしい処理はありません。
返り値の部分のみ該当する型のゼロ値が必要になってくるため、MethodType
という型にgetZeroValueというメソッドを用意してそれを使うようにしました。
おわりに
以上GoのAbstract Syntax Tree(AST)を利用したコードの自動生成でした!
簡単なコード生成はIDEの機能内で行えることもありますがASTの実装を活用することで、より複雑な生成プロセスにも対応することができます。
気になる方はぜひぜひastで遊んでみてください!