LoginSignup
3
3

More than 1 year has passed since last update.

【Go言語】静的解析で構造体のフィールドを参照している機能を調べたい

Last updated at Posted at 2021-12-16

前置き

ある日顧客から、こんな依頼を受けました。

「あるDBテーブルのカラムがそれぞれどの機能で参照されているか調べて欲しい」

OR Mapper的な仕組みを利用している為、カラムは全てある構造体のフィールドとして定義されています。
参照されているコードの個所を調べるだけならばgrepでもある程度可能そうですが、grepには下記問題があります。

  • 同名の変数や他構造体の同名フィールドもhitしてしまう
  • 参照されている箇所しか分からず、その処理(関数)を呼び出している関数を辿れない

エディタやIDEの機能を使い変数参照元と関数呼び出し元を辿ることも出来そうですが、何しろ対象の構造体は フィールド数が数百個ある 為、手作業でやっていたら相当時間がかかりそうです(´・ω・`)

そこで、静的解析プログラムを作成して、構造体フィールドとそれを呼び出している機能(関数)を自動で出力することを思い立ちました。

解析対象プロジェクト(サンプル)

パッケージ構成

sample/
| api/
| | api.go
| domain/
| | domain.go
| service/
| | service.go
| go.mod

パッケージの依存関係は api ← service ← domain となっています。
解析対象の構造体はdomainパッケージに定義されています。
apiパッケージに定義されているexportedな関数が機能を表す為、対象構造体各フィールドがどのapiパッケージ関数を起点にした処理で参照しているかを出力することをゴールとします。

コード

domain.go
package domain

import (
    "fmt"
    "time"
)

type Hoge struct {
    Field1 string
    Field2 int
    Field3 time.Time
    Field4 string
    Field5 int
    Field6 time.Time
    // 実際はフィールドが700個以上ある・・
}

func (h Hoge) Call() {
    fmt.Println(h.Field1, h.Field2)
}

func CallHoge(h Hoge) {
    fmt.Println(h.Field3, h.Field4)
}

type Fuga struct {
    Name string
    Age  int
}
service.go
package service

import (
    "fmt"
    "sample/domain"
    "time"
)

func ServeA() {
    hoge := domain.Hoge{}
    hoge.Call()
}

func ServeB() {
    hoge := domain.Hoge{}
    domain.CallHoge(hoge)
}

func ServeC() {
    hoge := domain.Hoge{
        Field5: 999,
        Field6: time.Now(),
    }
    fmt.Println(hoge)
}
api.go
package api

import (
    "fmt"
    "sample/domain"
    "sample/service"
)

func FeatureA() {
    service.ServeA()
}

func FeatureB() {
    hoge := domain.Hoge{
        Field1: "test",
    }
    fmt.Println(hoge)

    service.ServeB()
}

func FeatureC() {
    subFeature()
}

func subFeature() {
    hoge := domain.Hoge{
        Field2: 123,
    }
    fmt.Println(hoge)

    service.ServeC()
}

func FeatureAll() {
    FeatureA()
    FeatureB()
    FeatureC()
}

解析プログラム作成

雛形作成

skeletonというツールを利用して静的解析ツール雛形をカンタンに生成できます。
https://github.com/gostaticanalysis/skeleton
※インストール方法は上記リポジトリのREADMEを参照

生成

$ skeleton sfused

※sfusedは、find struct fields used をテキトーに省略した名前です😓

下記の様なプロジェクト雛形が生成されます。

sfused/
| cmd/
| | sfused/
| | | main.go
| testdata/
| go.mod
| sfused.go
| sfused_test.go

go.modを修正

golang.org/x/tools/go/callgraph/vta のバージョンが0.1.6以上だと謎のpanicが起きたので、日和って0.1.5に下げました。

go.mod
@@ -4,5 +4,5 @@ go 1.16

 require (
        github.com/gostaticanalysis/testutil v0.4.0
-       golang.org/x/tools v0.1.8
+       golang.org/x/tools v0.1.5
 )

main関数を修正

もともとgo vetから呼び出されることを想定したunitcheckerが使われていましたが、コマンド実行を想定したsinglecheckerに変更します。

main.go
@@ -3,7 +3,7 @@ package main
 import (
        "sfused"

-       "golang.org/x/tools/go/analysis/unitchecker"
+       "golang.org/x/tools/go/analysis/singlechecker"
 )

-func main() { unitchecker.Main(sfused.Analyzer) }
+func main() { singlechecker.Main(sfused.Analyzer) }

Analyzerを修正

詳細は割愛しますが、ざっくり説明すると下記を行なっています

  1. 対象構造体のフィールド一覧を収集
  2. どの関数が別のどの関数を呼び出しているかを解析
  3. 対象構造体フィールドがどの関数で参照されているかを解析

そしてそれぞれの解析結果をマージして出力します。

sfused.go
package sfused

import (
    "fmt"
    "go/ast"
    "go/types"
    "strings"

    "golang.org/x/tools/go/analysis"
    "golang.org/x/tools/go/analysis/passes/buildssa"
    "golang.org/x/tools/go/analysis/passes/inspect"
    "golang.org/x/tools/go/ast/inspector"
    "golang.org/x/tools/go/callgraph"
    "golang.org/x/tools/go/callgraph/cha"
    "golang.org/x/tools/go/callgraph/vta"
    "golang.org/x/tools/go/ssa"
    "golang.org/x/tools/go/ssa/ssautil"
)

const doc = "sfused finds which feature uses specific struct fields"

var Analyzer = &analysis.Analyzer{
    Name: "sfused",
    Doc:  doc,
    Run:  run,
    Requires: []*analysis.Analyzer{
        inspect.Analyzer,
        buildssa.Analyzer,
    },
    FactTypes: []analysis.Fact{
        new(targetFieldsFact),
        new(calleesFact),
        new(refersFact),
    },
}

// とりいそぎ対象の構造体やパッケージは定数で指定
// TODO: フラグで渡す
const (
    rootPkg      = "sample"             // root package
    apiPkg       = "sample/api"         // API package
    targetStruct = "sample/domain.Hoge" // target struct name
)

var structPkg, structName string

func init() {
    pkgAndName := strings.SplitN(targetStruct, ".", 2)
    structPkg, structName = pkgAndName[0], pkgAndName[1]
}

// Package Fact

type targetFieldsFact map[*types.Var]struct{} // 対象構造体フィールドセット

func (*targetFieldsFact) AFact() {}

// Object Facts

type calleesFact map[types.Object]struct{} // 呼び出し関数セット

func (*calleesFact) AFact() {}

type refersFact map[types.Object]struct{} // 参照フィールドセット

func (*refersFact) AFact() {}

func run(pass *analysis.Pass) (interface{}, error) {

    if !strings.HasPrefix(pass.Pkg.Path(), rootPkg) {
        return nil, nil
    }

    inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)

    // #### 解析対象構造体フィールドリストを収集 ####
    // 型情報から取得
    var targetFields targetFieldsFact
    if pass.Pkg.Path() == structPkg {
        targetFields = make(targetFieldsFact)

        var bt types.Type
        inspect.Preorder([]ast.Node{(*ast.TypeSpec)(nil)}, func(n ast.Node) {
            switch n := n.(type) {
            case *ast.TypeSpec:
                if n.Name.String() == structName {
                    bt = pass.TypesInfo.TypeOf(n.Name)
                }
            }
        })
        if bt == nil {
            return nil, fmt.Errorf("struct %s not found", targetStruct)
        }
        ubt := bt.Underlying()

        st := ubt.(*types.Struct)

        stSet := make(map[*types.Struct]struct{})
        stSet[st] = struct{}{}

        sfs := collectFields(st, stSet)
        for _, sf := range sfs {
            targetFields[sf] = struct{}{}
        }

        pass.ExportPackageFact(&targetFields)
    } else {
        for _, pf := range pass.AllPackageFacts() {
            if pf.Package.Path() == structPkg {
                pass.ImportPackageFact(pf.Package, &targetFields)
                break
            }
        }
    }

    // #### どの関数が別の何の関数を呼び出しているか解析 ####
    // callgraph(VTAアルゴリズム)を利用
    calleesFacts := make(map[types.Object]calleesFact)
    s := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA)
    initcg := cha.CallGraph(s.Pkg.Prog)
    cg := vta.CallGraph(ssautil.AllFunctions(s.Pkg.Prog), initcg)
    if err := callgraph.GraphVisitEdges(cg, func(edge *callgraph.Edge) error {

        // 解析中パッケージと異なるパッケージのCallerを除外
        if edge.Caller.Func.Package() == nil || edge.Caller.Func.Package().Pkg.Path() != pass.Pkg.Path() {
            return nil
        }

        // プロジェクト外パッケージ関数の呼び出しを除外
        if !strings.HasPrefix(edge.Callee.Func.Package().Pkg.Path(), rootPkg) {
            return nil
        }

        // init関数の呼び出しは無視
        if edge.Caller.Func.Name() == "init" || edge.Callee.Func.Name() == "init" {
            return nil
        }

        // Calleeがnested functionだったら無視
        if edge.Callee.Func.Parent() != nil {
            return nil
        }

        // Callerがnested functionだったらtop-level functionに置き換える
        callerFunc := topLevelFunc(edge.Caller.Func)

        // validate
        if callerFunc.Object() == nil {
            return fmt.Errorf("nil object caller : %v\n", edge.Caller.Func)
        }
        if edge.Callee.Func.Object() == nil {
            return fmt.Errorf("nil object callee : %v\n", edge.Callee.Func)
        }

        callees, ok := calleesFacts[callerFunc.Object()]
        if !ok {
            callees = make(calleesFact)
            calleesFacts[callerFunc.Object()] = callees
        }
        callees[edge.Callee.Func.Object()] = struct{}{}

        return nil
    }); err != nil {
        return nil, err
    }

    // Object Factに保存
    var callers []types.Object
    for caller, callees := range calleesFacts {
        caller, callees := caller, callees

        callers = append(callers, caller)
        pass.ExportObjectFact(caller, &callees)
    }

    // #### 対象フィールドを参照している関数を収集 ####
    // AST解析
    refersFacts := make(map[types.Object]refersFact)
    var topLevelFunc types.Object
    inspect.Nodes([]ast.Node{
        (*ast.File)(nil),
        (*ast.FuncDecl)(nil),
        (*ast.Ident)(nil),
    }, func(n ast.Node, push bool) bool {

        switch n := n.(type) {
        case *ast.File:
            f := pass.Fset.File(n.Pos())
            return !strings.HasSuffix(f.Name(), "_test.go") && strings.Index(f.Name(), "testdata") < 0
        case *ast.FuncDecl:
            if push {
                topLevelFunc = pass.TypesInfo.ObjectOf(n.Name)
            } else {
                topLevelFunc = nil
            }
            return true
        case *ast.Ident:
            if !push {
                return false
            }
            if topLevelFunc == nil {
                return false
            }
            o := pass.TypesInfo.ObjectOf(n)

            if vr, ok := o.(*types.Var); ok {
                if _, ok := targetFields[vr]; ok {
                    refers, ok := refersFacts[topLevelFunc]
                    if !ok {
                        refers = make(refersFact)
                        refersFacts[topLevelFunc] = refers
                    }
                    refers[o] = struct{}{}
                }
            }

            return true
        default:
            return true
        }
    })

    // Object Factに保存
    for f, v := range refersFacts {
        f, v := f, v
        pass.ExportObjectFact(f, &v)
    }

    // #### 結果出力 ####
    if pass.Pkg.Path() == apiPkg {
        report := func(fs []types.Object, field types.Object) error {
            fmt.Print(field.String())

            for _, f := range fs {
                fmt.Printf("\t%s", f.String())
            }

            fmt.Println()

            return nil
        }

        var visitFunc func(fs []types.Object, f types.Object) error
        visitFunc = func(fs []types.Object, f types.Object) error {
            for _, ancestor := range fs {
                if ancestor == f {
                    // cancel infinite recursion
                    return nil
                }
            }
            fs = append(fs, f)

            var rf refersFact
            var cf calleesFact

            if f.Pkg().Path() == apiPkg {
                rf = refersFacts[f]
                cf = calleesFacts[f]
            } else {
                pass.ImportObjectFact(f, &rf)
                pass.ImportObjectFact(f, &cf)
            }

            for field := range rf {
                if err := report(fs, field); err != nil {
                    return err
                }
            }

            for callee := range cf {
                if err := visitFunc(fs, callee); err != nil {
                    return err
                }
            }

            return nil
        }

        // exportedな関数のみをマージする
        funcSet := make(map[types.Object]struct{})

        for referrer := range refersFacts {
            if !referrer.Exported() {
                continue
            }
            funcSet[referrer] = struct{}{}
        }

        for caller := range calleesFacts {
            if !caller.Exported() {
                continue
            }
            funcSet[caller] = struct{}{}
        }

        // 出力!!
        var fs []types.Object
        for caller := range funcSet {
            if err := visitFunc(fs, caller); err != nil {
                return nil, err
            }
        }
    }

    return nil, nil
}

func collectFields(st *types.Struct, stSet map[*types.Struct]struct{}) (vars []*types.Var) {
    for i := 0; i < st.NumFields(); i++ {
        sf := st.Field(i)

        sft := sf.Type()
        sfut := sft.Underlying()

        if cst, ok := sfut.(*types.Struct); ok {
            if _, ok := stSet[cst]; !ok {
                stSet[cst] = struct{}{}
                sfs := collectFields(cst, stSet)
                for _, sf := range sfs {
                    vars = append(vars, sf)
                }
            }
        }

        vars = append(vars, sf)
    }
    return vars
}

func topLevelFunc(f *ssa.Function) *ssa.Function {
    for {
        if p := f.Parent(); p == nil {
            return f
        } else {
            f = p
        }
    }
}

ビルド

$ cd somewhere/sfused
$ go install sfused/cmd/sfused

実行

起点となるパッケージを指定して実行します。

$ cd somewhere/sample
$ sfused sample/api

結果

フィールド\t関数1\t関数2\t関数3\t... の様にタブ区切り形式で出力しています。

エクセルに貼り付けた結果↓

スクリーンショット 2021-12-16 15.15.21.png

まぁまぁよさげ(^^)

所感

手作業でやると気が遠くなるフィールド数ですが、静的解析を使えばサクッと自動で出来ちゃいます!!

って言いたくて始めましたが地味にハマったりして結構時間かけてしまいました。
それでも手作業よりはかなり速いかな・・・

汎用的なツールだけでなく、こういったプロジェクトに特化したちょっとした手作業を省略する為のツールを作るのにも静的解析は有用だと思いました。

TODO

まだまだやることいっぱい。

  • 重複結果の排除
  • 構造体の入れ子や埋め込みの対応
  • 対象構造体名やパッケージをフラグで渡せる様に
  • 出力形式をもっと使いやすく(ソートとか)
  • 参照されていないフィールドの抽出
  • apiパッケージからは辿れない処理で参照されているフィールドの抽出
  • golang.org/x/toolsのバージョンを上げる
  • 汎用ツールにする(^ω^;
  • テスト(^ω^;

参考資料(大感謝🙏)

3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3