LoginSignup
9
4

More than 3 years have passed since last update.

末尾再帰最適化マクロを作ってみた

Last updated at Posted at 2020-12-21

はじめに

Julia自体は末尾再帰最適化を行いませんが、簡単なマクロで末尾再帰を消去できたので、紹介します。

作ってみたと言いながら、実は五年前くらいに作ってgithubにおいたまま放置していたのですが、最近動かないという問い合わせがきたりしたので、最新バージョンのJuliaに対応させたついでに、マクロの使い方を復習してみたところです。ちょうど今年のAdvent Calendarにはマクロの話題はないみたいなので、投稿することにしました。

使ってみる

まずは、作成したパッケージのインストール、使い方を説明します。

インストール

野良パッケージとしてgithubにおいてあるので、URLを指定してインストールできます。

julia> ]
(@v1.5) pkg> add https://github.com/TakekazuKATO/TailRec.jl
    Cloning git-repo `http://github.com/TakekazuKATO/TailRec.jl`
   Updating git-repo `http://github.com/TakekazuKATO/TailRec.jl`
   Updating registry at `~/.julia/registries/General`
   Updating git-repo `https://github.com/JuliaRegistries/General.git`
  Resolving package versions...
Updating `~/.julia/environments/v1.5/Project.toml`
  [f6209947] + TailRec v0.2.0 `http://github.com/TakekazuKATO/TailRec.jl#master`
Updating `~/.julia/environments/v1.5/Manifest.toml`
  [f6209947] + TailRec v0.2.0 `http://github.com/TakekazuKATO/TailRec.jl#master`
(@v1.5) pkg>

使用例

単純な末尾再帰の例

まずは、簡単な末尾再帰関数を考えてみます。以下は再起にする意味はないけど1からNまでの総和を求める関数です。

julia> function sumR(x, i=0)
       if x == 0
         i
       else
         sumR(x-1, i+x)
       end
       end
sumR (generic function with 2 methods)

julia> sumR(10)
55

julia> sumR(1000000)
ERROR: StackOverflowError:
Stacktrace:
 [1] sumR(::Int64, ::Int64) at ./REPL[6]:5 (repeats 79984 times)

SumR(x,0)として呼び出すと、第一引数を1ずつ減らしながら第二引数にそれまでのxを足した値を与えながら再帰呼び出しを繰り返し、第一引数が0になるとその時の第二引数(第一引数の総和)を返します。これで総和が計算できますが、2回目に1000000を第一引数にして呼び出したところでスタックオーバーフローになっているのがわかります。これは再帰が深すぎて関数呼び出しスタックがオーバーフローしたためです。

TailRec.jlを使った最適化

さきほどの関数sumRを定義の先頭に@tailrecマクロをつけて全く同じようにsumTCOという名前で定義します。

using TailRec

julia> @tailrec function sumTCO(x, i=0)
       if x == 0
         i
       else
         sumTCO(x-1, i+x)
       end
       end
sumTCO (generic function with 2 methods)

julia> sumTCO(10)
55

julia> sumTCO(1000000)
500000500000

今度は1000000でもちゃんと計算できています。

末尾再帰と末尾再帰最適化

末尾再帰とは、再帰呼び出し関数のうち、再帰呼び出しが関数の一番最後に実行される関数をいいます。ここで一番最後というのは文面上の最後ではなく、実行手順の一番最後になります。なので、再帰呼び出し部分が ``1+sumTCO(x-1,i)''のようになっていると、再起呼び出しの後に1+の計算をするので、末尾再帰になりません。

ここで、関数呼び出しでシステム上なにが行われるか考えてみると、今の状態や関数呼び出し位置をメモリスタックに積んで、関数定義にジャンプします。関数が終わればスタックから呼び出し位置と状態をポップして、元の位置に戻ります。

このとき、ある関数の中に再帰呼び出しがあって、かつこの再帰呼び出しが関数の実行手順の一番最後(つまり末尾再帰)だったとすると、再帰呼び出しから復帰した後に、関数は終了するため、呼び出し前の状態や復帰位置は不要でスタックに積む必要はなく、引数を新しい呼び出しに更新して関数冒頭にジャンプすればがいいことになります。このように末尾再帰をジャンプやループに書き換えることを末尾再帰最適化(Tail Recursive Call Optimization)、あるいは、末尾再帰消去(Tail Recursive Call Elimination)と呼びます。

@tailrecマクロは、末尾再帰呼び出しをジャンプに置き換えるように関数定義を書き換えるマクロです。実際にどのような書き換えが行われたか、マクロを適用した結果を展開する@macroexpandマクロを使って確認してみます。

julia> @macroexpand@tailrec function sumTCO(x, i=0)
       if x == 0
         i
       else
         sumTCO(x-1, i+x)
       end
       end
:(function sumTCO(x, i = 0)
      $(Expr(:symboliclabel, :retry))
      begin
          #= REPL[30]:1 =#
          #= REPL[30]:2 =#
          if x == 0
              #= REPL[30]:3 =#
              i
          else
              #= REPL[30]:5 =#
              begin
                  (x, i) = (x - 1, i + x)
                  $(Expr(:symbolicgoto, :retry))
              end
          end
      end
  end)

:(functionからが@macroexpandの出力結果です。関数定義の冒頭に $(Expr(:symboliclabel, :retry))が加わり、再帰呼び出し部分が

(x, i) = (x - 1, i + x)
$(Expr(:symbolicgoto, :retry))

に変わっています。ちょっと読みにくいので、これを普通の関数定義の書き方で書けば次のようになります。

function sumTCO(x, i=0)
       @label retry
       if x == 0
         i
       else
         (x, i) = (x-1, i+x)
         @goto retry
       end
end

つまり、再帰呼び出し部分をgotoで関数冒頭にジャンプするように変換しています。このように、末尾再帰になっている部分を差し替えることで再帰呼び出しを消去し、最適化します。

TailRec.jlの仕組み

TailRec.jlの仕組みをみていきます。ソースコードを下記に挙げます。

module TailRec
export @tailrec

macro tailrec(func)
    fargs=map(e->if isa(e,Expr) e.args[1] else e end,func.args[1].args)
    fbody=func.args[2]
    fbody=rewrite(fbody,fargs)
    func.args[2]=Expr(:block,:(@label retry),fbody)
    esc(func)
end

function rewrite(expr,args,callflag=false)
    if !isa(expr,Expr)
        expr
    elseif expr.head == :call && expr.args[1] == args[1]
        if callflag
            @warn "Not tail recursive call is found."
            expr
        else
            newargs=Expr(:tuple)
            newargs.args=args[2:end]
            olda nrgs=Expr(:tuple)
            oldargs.args=expr.args[2:end]
            Expr(:block, Expr(:(=),newargs,oldargs), :(@goto retry) )
        end
    elseif expr.head == :block
        expr.args[end]=rewrite(expr.args[end],args, expr.head==:call || callflag)
        expr
    else
        expr.args = map(a->rewrite(a,args, expr.head==:call ||callflag),expr.args)
        expr
    end
end
end

macro tailrec(func) というように関数定義のfunctionの代わりにmacroと書くとマクロを定義できます。引数funcには最適化したい関数の定義文のAST(abstract syntax tree: 抽象構文木)が与えられます。ASTというのはjuliaのプログラムを構文解析した結果です。dump()関数でASTの中身を表示できるので、簡単なマクロで確認してみましょう。

julia> macro mtest(func)
       dump(func)
       end
@mtest (macro with 1 method)

julia> @mtest f(x)=x
Expr
  head: Symbol =
  args: Array{Any}((2,))
    1: Expr
      head: Symbol call
      args: Array{Any}((2,))
        1: Symbol f
        2: Symbol x
    2: Expr
      head: Symbol block
      args: Array{Any}((2,))
        1: LineNumberNode
          line: Int64 1
          file: Symbol REPL[2]
        2: Symbol x

juliaではASTは、Exprという型で木構造を表現していて、Exprはheadとargsという属性を持っています。Expr型は木構造のノードで、headにノードのタイプ、argsに子ノードが与えられます。juliaのマクロは、この木構造を書き換えることでプログラムを書き換えます。

末尾再帰関数の場合のASTをみてみます。

julia> @mtest function sumTCO(x, i=0)
              if x == 0
                i
              else
                sumTCO(x-1, i+x)
              end
              end
Expr
  head: Symbol function
  args: Array{Any}((2,))
    1: Expr
      head: Symbol call
      args: Array{Any}((3,))
        1: Symbol sumR
        2: Symbol x
        3: Expr
          head: Symbol kw
          args: Array{Any}((2,))
            1: Symbol i
            2: Int64 0
    2: Expr
      head: Symbol block
      args: Array{Any}((3,))
        1: LineNumberNode
          line: Int64 1
          file: Symbol REPL[11]
        2: LineNumberNode
          line: Int64 2
          file: Symbol REPL[11]
        3: Expr
          head: Symbol if
          args: Array{Any}((3,))
            1: Expr
              head: Symbol call
              args: Array{Any}((3,))
                1: Symbol ==
                2: Symbol x
                3: Int64 0
            2: Expr
              head: Symbol block
              args: Array{Any}((2,))
                1: LineNumberNode
                2: Symbol i
            3: Expr
              head: Symbol block
              args: Array{Any}((2,))
                1: LineNumberNode
                2: Expr

この木構造から、末尾再帰の部分をみつけて@gotoに置き換えることができればいいということになります。

ではもう一度@tailrecマクロの定義をみてみます。

macro tailrec(func)
    fargs=map(e->if isa(e,Expr) e.args[1] else e end,func.args[1].args)
    fbody=func.args[2]
    fbody=rewrite(fbody,fargs)
    func.args[2]=Expr(:block,:(@label retry),fbody)
    esc(func)
end

funcは@tailrecに引数で与えた関数定義のASTです。関数定義は、関数名と仮引数、関数本体のノードを持ちます。まずmapで関数名と仮引数のリストから、デフォルト値の設定を消去しています。その次の行ではrewrite関数(この後に定義)で関数本体を書き換えています。最後に関数本体の最初に@labelを挟み込んでいます。

このように、Expr()で新しい木構造をつくり、もとの関数本体のノード(func.args[2])と差し替えることでASTを書き換えることができます。最後にescで書き換えたASTを関数定義そのものに戻しています。

次に、書き換えの本体であるrewrite関数をみてみます。

# 引数のexprは現在調べているノード、argsは関数名と仮引数のリスト、callflagは探索経路で関数呼び出しの引数かどうか
function rewrite(expr,args,callflag=false)
    if !isa(expr,Expr) # ノードが式ではない場合
        expr # ノードを書き換えずにそのまま返す
    elseif expr.head == :call && expr.args[1] == args[1] # 再帰呼び出しの場合
        if callflag # すでに再帰の後に関数呼び出しがある。末尾再帰にならないので警告
            @warn "Not tail recursive call is found."
            expr
        else # gotoでジャンプする前に仮引数を、再帰呼び出しの引数に更新する
            newargs=Expr(:tuple)
            newargs.args=args[2:end]
            olda nrgs=Expr(:tuple)
            oldargs.args=expr.args[2:end]
            # 再帰をgotoに書き換える
            Expr(:block, Expr(:(=),newargs,oldargs), :(@goto retry) )
        end
    elseif expr.head == :block # ノードがブロックの場合
        # 末尾再帰の場合は再帰はブロックの最後にあるので、最後の子ノードを探索
        expr.args[end]=rewrite(expr.args[end],args, expr.head==:call || callflag)
        expr
    else # それ以外の場合は、if,forなどの構文か関数呼び出し(演算子含む)
        # 子ノードを全て探索し書き換え。このとき、
           # 関数呼び出し場合(expr.head==:call)は、子ノードの後に関数呼び出しをするため、末尾再帰でないのでフラグを立てておく
           # それ以外の構文の場合は末尾再帰の可能性があるのでフラグを立てない
        expr.args = map(a->rewrite(a,args, expr.head==:call ||callflag),expr.args)
        expr
    end
end

この関数はASTの木構造を再起的にたどりながら再帰呼び出しを探し出して書き換えます。
こまかい判定などは上記プログラムにコメントで書き加えていますが、全体的な動作としては、木構造を探索しながら再帰部分を探していきますが、このときに再帰の後に、何か実行する部分があれば末尾再帰ではないので、警告を出すようにしています。末尾再帰を見つけたら、引数の更新をおこなってから再帰を@gotoによるジャンプに書き換えます。

おわりに

juliaのマクロの書き方自体は、公式のドキュメントや、日本語の解説などがありますが、基礎的なことが書いてあって、実際にASTをいじってプログラムを書き換える様子がなかなかピンときません。一方でさまざまなマクロを含むパッケージが公開されていますが、複雑で全体的な動きを追いかけるのが難しいです。TailRec.jlはコンパクトで、かつ、実際的なので例としてちょうどいいと思い紹介しました。

なお、Juliaが末尾再帰最適化に対応していない理由は、Juliaの場合普通にループをかけて十分に読みやすく効率的なので、あまり最適化が有効な場面が少なくあまりメリットがないということのようです。なので、このマクロが有用な場面はごくごく限られると思います。

9
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
4