3
1

GoでAST使ってコード自動生成するぞ!

Last updated at Posted at 2023-12-17

CastingONE Advent Calendar 2023 17日目の記事です。

はじめに

実装を進めていると、ふと「なんか毎回同じようなこと書いているな」って思うことありません??
僕は結構あります!「repositoryのinterfaceで実装書く時、毎回traceの処理書いてるな〜」とか。
結構書き忘れもあるので、interfaceを元にこの辺りの処理全部生成できたら嬉しいと毎回思ってます笑

ということで、今回はASTを使ってうまくコード生成する方法を紹介します!
ASTとはなんぞやというところから実装まで紹介するので、ぜひ最後まで読んでいってください〜

ちなみに、本書の実装はこちらのリポジトリで公開しているので、詳しくみたいという方は参考にしてください!

ASTとは??

AST(Abstract Syntax Tree)はソースコードの構造を、ツリー構造で表現したものになります。このツリー構造はコードの構文的な構造を表します。プログラムの各要素、例えば変数や関数、演算子等が階層構造で表されるイメージです。

例えば以下のようなファイルを考えてみましょう!

user.go
package main

type User struct {
  Name string
}

こちらのファイルのコードについて、astの構造を出力してみると以下のようになります(必要な箇所のみ抜粋)。
ファイル内の構造がastのオブジェクトに反映されており、ツリー構造になっていることがなんとなくわかります。

     0  *ast.File {
     1  .  Package: 1
     2  .  Name: *ast.Ident {
     4  .  .  Name: "main"
     5  .  }
     6  .  Decls: []ast.Decl (len = 1) {
     7  .  .  0: *ast.GenDecl {
    11  .  .  .  Specs: []ast.Spec (len = 1) {
    12  .  .  .  .  0: *ast.TypeSpec {
    13  .  .  .  .  .  Name: *ast.Ident {
    15  .  .  .  .  .  .  Name: "Person"
    16  .  .  .  .  .  .  Obj: *ast.Object {
    17  .  .  .  .  .  .  .  Kind: type
    18  .  .  .  .  .  .  .  Name: "Person"
    20  .  .  .  .  .  .  }
    21  .  .  .  .  .  }
    23  .  .  .  .  .  Type: *ast.StructType {
    24  .  .  .  .  .  .  Struct: 27
    25  .  .  .  .  .  .  Fields: *ast.FieldList {
    27  .  .  .  .  .  .  .  List: []*ast.Field (len = 1) {
    28  .  .  .  .  .  .  .  .  0: *ast.Field {
    29  .  .  .  .  .  .  .  .  .  Names: []*ast.Ident (len = 1) {
    30  .  .  .  .  .  .  .  .  .  .  0: *ast.Ident {
    32  .  .  .  .  .  .  .  .  .  .  .  Name: "Name"
    33  .  .  .  .  .  .  .  .  .  .  .  Obj: *ast.Object {
    34  .  .  .  .  .  .  .  .  .  .  .  .  Kind: var
    35  .  .  .  .  .  .  .  .  .  .  .  .  Name: "Name"
    37  .  .  .  .  .  .  .  .  .  .  .  }
    38  .  .  .  .  .  .  .  .  .  .  }
    39  .  .  .  .  .  .  .  .  .  }
    40  .  .  .  .  .  .  .  .  .  Type: *ast.Ident {
    42  .  .  .  .  .  .  .  .  .  .  Name: "string"
    43  .  .  .  .  .  .  .  .  .  }
    66  }

このようなツリー構造を活用してファイルの内部構造を解析することで、自作リンターを開発したり、特定の要件に合わせたコードを生成したりすることが可能になります。

astの出力を確認する方法

上記の出力は、以下のようなコードで確認することができます。

main.go
package main

import (
	"go/ast"
	"go/parser"
	"go/token"
)

func main() {
	f, err := parser.ParseFile(token.NewFileSet(), "user.go", nil, 0)
	if err != nil {
		panic(err)
	}
	ast.Print(nil, f)
}

自動生成してみよう

ASTとは何かを見てきたところで、本題の自動生成の実装についてみていきます!
今回はrepsitoryのinterfaceの実装を出力してみます。

例えば、こんな感じのファイルから

package domain

import "context"

type User struct {
	ID    int
	Name  string
	Email string
}

type UserRepository interface {
	FindById(ctx context.Context, id int) (User, error)
	FindAll(ctx context.Context) ([]User, error)
	Store(ctx context.Context, user *User) error
}

以下のような出力ができるように実装してみます!

package infrastructure

import (
	"context"
	"log"

	"github.com/hiroaki-u/ast-playground/examples/domain"
)

type userRepository struct {
}

func NewUserRepository() domain.UserRepository {
	return &userRepository{}
}
func (r *userRepository) FindById(ctx context.Context, id int) (domain.User, error) {
	log.Default().Println("UserRepository.FindById")
	return domain.User{}, nil
}

func (r *userRepository) FindAll(ctx context.Context) ([]domain.User, error) {
	log.Default().Println("UserRepository.FindAll")
	return nil, nil
}

func (r *userRepository) Store(ctx context.Context, user *domain.User) error {
	log.Default().Println("UserRepository.Store")
	return nil
}

処理全体の流れ

今回実装した処理の流れをみていきます。大きく以下の通りです。

  • ASTを取得して、必要な情報を抽出する
  • 取得した情報から実装内容を出力する

処理は以下のようになっています
※今回はurfave/cliというcliライブラリを用いており、input_file等はcliのオプションとして受け取るようにしています

main.go
func createRepository(cCtx *cli.Context) {
	// inputファイルを取得する
	in := cCtx.String("input_file")
	if in == "" {
		log.Fatal("input_file is required")
		return
	}
	// outputファイルを取得する
	out := cCtx.String("output_file")
	if out == "" {
		log.Fatal("output_file is required")
		return
	}

	// inputファイルを読み込み、interfaceとそのメソッドの情報を取得する
	repo, err := parseRepositoryStructure(in)
	if err != nil {
		log.Fatal(err)
		return
	}

	// inputファイルのパッケージ名を取得
	pkgName, err := getPackageName(in)
	if err != nil {
		log.Fatal(err)
		return
	}

	// 取得したrepositoryの構造体から、repositoryの実装ファイルに記載する内容を作成する
	builder := NewRepositoryContentBuilder()
	ss, err := builder.Execute(repo, pkgName, getDirectoryName(out))
	if err != nil {
		log.Fatal(err)
		return
	}

	// outputファイルに書き込む
	p, err := os.Create(out)
	if err != nil {
		log.Fatal(err)
		return
	}
	defer p.Close()
	if _, err := p.Write([]byte(ss)); err != nil {
		log.Fatal(err)
		return
	}
}

再掲になりますが、おおまかな処理の流れは以下の2つですので、これから詳細を見ていこうと思います

  • ASTを取得して、必要な情報を抽出する
  • 取得した情報から実装内容を出力する

ASTを取得して、必要な情報を抽出する

ここではASTを取得して、欲しい情報(repositoryのinterface)を抽出する過程をご紹介します。
まずは今回ASTから取得した値を格納しておくstructを見ていきましょう。

builder.go
type Repository struct {
	Name    string
	Methods []*Method
}

type Method struct {
	Name    string
	Args    MethodValues
	Returns MethodValues
}

type MethodValues []MethodValue

type MethodValue struct {
	Type   *MethodType
	Values []string
}

type MethodType struct {
	isSlice        bool   // slice
	isPointer      bool   // ポインタ
	isVariadic     bool   // 可変長引数
	isPrimitive    bool   // 基本型
	requirePkgName bool   // package名が必要な場合
	Value          string // 型名
}

MethodTypeは結構シンプルに書いています。
あらゆる型に対応する場合はもっと色々考える必要はありますが、今回はシンプルな型のみ対象としました。

続いてmain.goから呼び出されていたparseRepositoryStructure関数を見ていきます。

parser.go
package main

import (
	"go/ast"
	"go/parser"
	"go/token"
)


// ファイルを走査し、ファイル内で見つかったrepositoryを返す
func parseRepositoryStructure(file string) (Repository, error) {
	// ファイルをパースして、astを取得する
	repo := Repository{}
	f, err := parser.ParseFile(token.NewFileSet(), file, nil, parser.Mode(0))
	if err != nil {
		return repo, err
	}

	// astを走査して、必要な情報を抽出
	ast.Inspect(f, func(n ast.Node) bool {
		methods := []*Method{}
		switch x := n.(type) {
		case *ast.TypeSpec:
			// 対象がinterfaceであるため、interfaceの情報を取得する
			it, ok := x.Type.(*ast.InterfaceType)
			if !ok {
				return true
			}
			repo.Name = x.Name.Name

			// 関数の情報を取得
			for _, field := range it.Methods.List {
				funcType, ok := field.Type.(*ast.FuncType)
				if !ok {
					continue
				}
				method := Method{}
				method.Name = field.Names[0].Name

				// 引数と返り値を取得する
				method.Args = ExtractMethodValues(funcType.Params.List)
				method.Returns = ExtractMethodValues(funcType.Results.List)
				methods = append(methods, &method)
			}
			repo.Methods = methods
		}
		return true
	})
	return repo, nil
}

最初にgo/parserParseFile関数を用いて該当ファイルのastを取得しています。ParseFile関数の返り値の型は*ast.Fileです。

ちなみにast.Fileは以下のような構造体となっています。確かにファイル内の情報が詰め込まれていそうです。

go/ast/ast.go
type File struct {
    Doc     *CommentGroup // 関連するドキュメント;ない場合はnil
    Package token.Pos     // "package" キーワードの位置
    Name    *Ident        // パッケージ名
    Decls   []Decl        // トップレベルの宣言;ない場合はnil

    FileStart, FileEnd token.Pos       // ファイル全体の開始と終了位置
    Scope              *Scope          // パッケージスコープ(このファイルのみ)
    Imports            []*ImportSpec   // このファイルのインポート
    Unresolved         []*Ident        // このファイル内の未解決識別子
    Comments           []*CommentGroup // ソースファイル内のすべてのコメントのリスト
    GoVersion          string          //go:build や +build ディレクティブによる要求される最小Goバージョン
}

この構造体は、ast.Nodeというinterfaceを実装しています。
astには色々なオブジェクトが出てきますが、基本的にはast.Nodeというinterfaceを実装しており、それを用いて繰り返し構造を成り立たせています。
少し脱線したので、今回の処理の話を戻しましょう!

ast.Fileを取得したら、Inspect関数を使って走査して構造を捉えていきます。

ちなみにInspect関数は以下のような実装となっており、Nodeを実装しているstructについて走査処理を実行します。

go/ast/walk.go
func Inspect(node Node, f func(Node) bool) {
	Walk(inspector(f), node)
}

Inspect関数の第二引数は関数であるため、この部分で情報の抽出を行っています。
今回はこの第二引数の関数にて、必要となる情報の抽出をしてます。
astにより取得したい対象がinterfaceであるため、それに該当するInterfaceTypeというNodeのみを取得するようにしました。

ast.InterfaceTypeは以下のような構造をもちます。Methodsというフィールドがinterfaceがもつメソッドの情報を保持しているため、その後の処理にてMethodsからメソッドの情報を取得していきます。

go/ast/ast.go
// An InterfaceType node represents an interface type.
InterfaceType struct {
	Interface  token.Pos  // position of "interface" keyword
	Methods    *FieldList // list of embedded interfaces, methods, or types
	Incomplete bool       // true if (source) methods or types are missing in the Methods list
}

メソッドの引数や返り値はExtractMethodValues関数にて処理するようにしました。詳細は以下の通りです。

parser.go
func ExtractMethodValues(list []*ast.Field) []MethodValue {
	mvs := []MethodValue{}
	for _, param := range list {
		mv := MethodValue{}
		mt := &MethodType{}
		mv.Type = IdentifyNodeType(param.Type, mt)
		if param.Names != nil {
			for _, p := range param.Names {
				mv.AppendValue(p.Name)
			}
		}
		mvs = append(mvs, mv)
	}
	return mvs
}

個々のメソッドについては、IdentifyNodeType関数にて型の分類をしています。詳細を見ていきましょう。

parser.go

// 各引数や返り値の型を特定して、MethodTypeに格納する
func IdentifyNodeType(t ast.Expr, mt *MethodType) {
	switch t.(type) {
	// sliceの場合
	case *ast.ArrayType:
		se := t.(*ast.ArrayType).Elt
		mt.isSlice = true
		IdentifyNodeType(se, mt)
	// pointer型の場合
	case *ast.StarExpr:
		se, _ := t.(*ast.StarExpr).X.(*ast.Ident)
		if se != nil {
			if !isPrimitive(se) {
				mt.isPointer = true
				mt.requirePkgName = true
				mt.Value = se.Name
			}
		} else {
			se, _ := t.(*ast.StarExpr).X.(*ast.SelectorExpr)
			x := se.X.(*ast.Ident)
			sel := se.Sel
			mt.isPointer = true
			mt.Value = x.Name + "." + sel.Name
		}
	// シンプルな型の場合(primitive型やstruct)
	case *ast.Ident:
		se := t.(*ast.Ident)
		mt.Value = se.Name
		if !isPrimitive(t.(*ast.Ident)) {
			mt.requirePkgName = true
		}
	// package + structの場合
	case *ast.SelectorExpr:
		x := t.(*ast.SelectorExpr).X.(*ast.Ident)
		sel := t.(*ast.SelectorExpr).Sel
		mt.Value = x.Name + "." + sel.Name
	// 可変引数
	case *ast.Ellipsis:
		se := t.(*ast.Ellipsis).Elt
		mt.Value = "..." + se.(*ast.Ident).Name
	default:
	}
}

引数のast.Exprは式ノード(Expression Node)を表すinterfaceです。
ちなみに式のノードは値を生成するプログラムを指します。関数の呼び出しだけでなく、値自体もそれに該当します。

IdentifyNodeType関数では、InterfaceType.Methodsに入っていた値の型ごとに処理を分類しています。
他にも色々な型はありますが、今回はシンプルな型のみを扱うようにしました。
以上がastで構造を取得するコードでした!

取得した情報から実装内容を出力する

次にastによって取得した値を出力していきます。
astutilを用いても実装できると思いますが、今回は構造の変更しやすさを考えてtemplate使って出力する方法を選択しました。

最終的な出力は以下のような形でしたので、まずはこれに該当するテンプレートを作成します。

package infrastructure

import (
	"context"
	"log"

	"github.com/hiroaki-u/ast-playground/examples/domain"
)

type userRepository struct {
}

func NewUserRepository() domain.UserRepository {
	return &userRepository{}
}
func (r *userRepository) FindById(ctx context.Context, id int) (domain.User, error) {
	log.Default().Println("UserRepository.FindById")
 
	return domain.User{}, nil
}

func (r *userRepository) FindAll(ctx context.Context) ([]domain.User, error) {
	log.Default().Println("UserRepository.FindAll")
 
	return nil, nil
}

func (r *userRepository) Store(ctx context.Context, user *domain.User) error {
	log.Default().Println("UserRepository.Store")
 
	return nil
}

用意したテンプレートはこちらの2つ。

new_func.tpl
package {{ .Package }}

type {{ .LowerName }} struct {
}

func New{{ .Name }}() {{ .InterfacePackage }}.{{ .Name }} {
	return & {{ .LowerName }}{}
}
method.tpl
func ({{ .ReceiverValue }} *{{ .ReceiverType }}) {{ .MethodName }}({{ .Args }}) ({{ .ReturnArgs }}) {
  {{ .Body.Log }}

  {{ .Body.ReturnValue }}
}

これらのファイルはembedを使ってtemplateを用意しておきます。

builder.go
var (
	//go:embed templates/*
	templates embed.FS

	// 関数用テンプレート
	methodTemplate = template.Must(template.ParseFS(templates, "templates/method.tpl"))
	// New関数用テンプレート
	newFuncTemplate = template.Must(template.ParseFS(templates, "templates/factory.tpl"))
)

Repository生成専用のstructにこれらのtemplateをもたせて、methodとnew関数の生成処理を実行しました。

builder.go
type RepositoryContentBuilder struct {
	MethodTemplate  *template.Template
	NewFuncTemplate *template.Template
}

func NewRepositoryContentBuilder() *RepositoryContentBuilder {
	return &RepositoryContentBuilder{
		MethodTemplate:  methodTemplate,
		NewFuncTemplate: newFuncTemplate,
	}
}

methodとnew関数の生成過程はほぼ一緒なので、今回はmethodを例に処理をみていきます。

builder.go

// メソッドの作成
func (rc *RepositoryContentBuilder) createMethod(repo Repository, pkgName string) ([]string, error) {
	res := []string{}

	for _, method := range repo.Methods {
		returnList := []string{}
		for _, v := range method.Returns {
			returnList = append(returnList, v.Type.getZeroValue(pkgName))
		}
		body := methodBodyParameter{}
		body.ReturnValue = "return " + strings.Join(returnList, ", ")
		body.Log = `log.Default().Println("` + repo.Name + "." + method.Name + `")`

		content := &methodParameter{
			ReceiverValue: "r",
			ReceiverType:  repo.getLowerName(),
			MethodName:    method.Name,
			Args:          method.Args.GetTemplate(pkgName),
			ReturnArgs:    method.Returns.GetTemplate(pkgName),
			Body:          body,
		}
		var buf bytes.Buffer
		if err := rc.MethodTemplate.Execute(&buf, content); err != nil {
			return nil, err
		}
		res = append(res, buf.String())
	}
	return res, nil
}

type methodParameter struct {
	ReceiverValue string
	ReceiverType  string
	MethodName    string
	Args          string
	ReturnArgs    string
	Body          methodBodyParameter
}

type methodBodyParameter struct {
	Log         string
	ReturnValue string
}

基本的には用意したtemplateの値を構築するだけなので、あまりややこしい処理はありません。
返り値の部分のみ該当する型のゼロ値が必要になってくるため、MethodTypeという型にgetZeroValueというメソッドを用意してそれを使うようにしました。

おわりに

以上GoのAbstract Syntax Tree(AST)を利用したコードの自動生成でした!
簡単なコード生成はIDEの機能内で行えることもありますがASTの実装を活用することで、より複雑な生成プロセスにも対応することができます。
気になる方はぜひぜひastで遊んでみてください!

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1