はじめに
「再帰関数は再帰呼び出しの代わりにスタックを用いて書き直すことができる」という話がありますが、その具体的な方法の1つについて説明を試みます。
原理
プログラミング言語処理系における、関数呼び出しの実装方法として、以下のようなものがあります。
- 関数呼び出しの際は、現在のローカル変数と、関数呼び出しから戻った際の再開位置を、スタックに積む。新たにローカル変数を確保して、関数の引数を代入したうえで、該当関数の先頭にジャンプする
- 関数から戻る際は、ローカル変数と再開位置をスタックから取り出して、再開位置にジャンプする
つまり、スタックがあってジャンプができれば、あらゆる関数呼び出し(再帰呼び出しを含む)を表現できます。ジャンプを無限ループ+条件分岐で表現すれば、ループに変換できたことになります。
実践
こちらの再帰関数をループに変換しましょう。(言語はPythonを使用)
def fib(n):
if n < 2:
return n
else:
return fib(n-1) + fib(n-2)
まず、再帰呼び出しの結果を必ず変数に代入するようにしたうえで、関数の先頭と再帰呼び出し行に連番をつけましょう。
def fib(n):
# 0
if n < 2:
return n
else:
ret1 = fib(n-1) # 1
ret2 = fib(n-2) # 2
return ret1 + ret2
あとは、以下のコードの...
の部分を埋めていきます。(case
文は、連番の数だけ用意)
def fib_loop(n):
stack = [(0, {'n': n})]
while stack:
match stack.pop()
case 0, env:
...
case 1, env:
...
case 2, env:
...
return ret
使われている変数の意味は以下の通りです。
変数 | 意味 |
---|---|
env | 関数のローカル変数の辞書 |
stack | 位置とenv を保存するスタック |
ret | 戻り値 |
変換のルールは以下の通りです。
- 該当の連番の位置から、
return
もしくは別の連番の位置までが変換対象となります - 関数内で使われている変数は、
env
内の値を使う形に書き換えます
(n
→env['n']
) - 関数の先頭で、位置0と引数を登録した辞書をスタックに積みます
(stack = [(0, {'n': n})]
) -
return xxx
は、以下のように書き換えます
ret = xxx
continue
- 再帰呼び出し
fib(xxx)
の存在する行は、以下のように書き換えます(連番をN
とします)
stack += (N, env), (0, {'n': xxx})
continue
さらに、新たなcase
の分岐を追加して、それ以降のコードは、この分岐の下に追加します。
case N, env:
env['retN'] = ret
最終的には、以下のようになります。
def fib_loop(n):
stack = [(0, {'n': n})]
while stack:
match stack.pop():
case 0, env:
if env['n'] < 2:
ret = env['n']
continue
else:
stack += (1, env), (0, {'n': env['n'] - 1})
continue
case 1, env:
env['ret1'] = ret
stack += (2, env), (0, {'n': env['n'] - 2})
continue
case 2, env:
env['ret2'] = ret
ret = env['ret1'] + env['ret2']
continue
return ret
注意
- 位置と
env
の組だけをスタックに積めばgoto相当になるので、頑張れば再帰関数内にループがある場合も再現できると思います - この方法では対応不可な場合がいつもあると思いますが(
try
、yield
など)、ネタということでお許しください
最後に
「再帰関数は再帰呼び出しの代わりにスタックを用いて書き直すことができる」のは事実ではありますが、ご覧の通り動作を追うのが厳しくなるので、素直に再帰を使うことをお勧めします。
余談
ローカル変数と再開位置の管理をジェネレータに任せることもできます。見た目だけは再帰と変わりませんが、動作を追うのはさらに厳しくなります。
def fib_gen(n):
def fib(n):
if n < 2:
return n
else:
return (yield n-1) + (yield n-2)
stack, ret = [fib(n)], None
while stack:
try:
stack.append(fib(stack[-1].send(ret)))
ret = None
except StopIteration as e:
stack.pop()
ret = e.value
return ret
ソースコード
やり方が分からないという人向けに、同じ変換を行うpythonコードを載せておきます。コード末尾の文字列を、目的の再帰関数に書き換えて、試してみてください。実装が中途半端で、変換できない場合が多々あると思いますので、ご注意ください。
import ast
class Rec2Loop(ast.NodeTransformer):
def visit_FunctionDef(self, f):
self.v = local_vars(f)
self.f = f
self.n = 0
org = f.body
f.body = ast.parse('''
stack = [(0, {})]
while stack:
match stack.pop():
case 0, env:
pass
return ret''').body
init_env(f, f.body[0].value.elts[0].elts[1], [ast.Name(a.arg, ast.Load()) for a in f.args.args])
f.body[1].body[0].cases[0].body = self.visits(org)
return ast.fix_missing_locations(f)
def visit_If(self, i):
i.test = self.visit(i.test)
(tgt := self.tgt).append(i)
i.body = self.visits(i.body)
i.orelse = self.visits(i.orelse)
self.tgt = tgt
def visit_Return(self, r):
return [ast.Assign([ast.Name('ret', ast.Store())], self.visit(r.value)), ast.Continue()]
def visit_Call(self, c):
if not isinstance(c.func, ast.Name) or c.func.id != self.f.name:
return self.generic_visit(c)
a = ast.parse('''
stack += (None, env), (0, {})
continue''').body
init_env(self.f, a[0].value.elts[1].elts[1], [self.visit(a) for a in c.args])
self.n += 1
a[0].value.elts[0].elts[0].value = self.n
self.tgt += a
self.f.body[1].body[0].cases.append(ast.parse(f'''
match _:
case {self.n}, env:
env['ret{self.n}'] = ret''').body[0].cases[0])
self.tgt = self.f.body[1].body[0].cases[-1].body
return create_ref(f'ret{self.n}', ast.Load())
def visit_Name(self, n):
if n.id not in self.v:
return n
return create_ref(n.id, n.ctx)
def visits(self, l):
tgt = self.tgt = []
for e in l:
v = self.visit(e)
if v is None:
continue
elif isinstance(v, ast.AST):
v = [v]
self.tgt += v
return tgt
def local_vars(t):
r = set()
for n in ast.walk(t):
if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Store):
r.add(n.id)
elif isinstance(n, ast.arg):
r.add(n.arg)
return r
def init_env(f, e, v):
e.keys = [ast.Constant(a.arg) for a in f.args.args]
e.values = v
def create_ref(name, ctx):
return ast.Subscript(ast.Name('env', ast.Load()), ast.Constant(name), ctx)
print(ast.unparse(Rec2Loop().visit(ast.parse('''
def fib(n):
if n<2:
return n
else:
return fib(n-1)+fib(n-2)
'''))))