Edited at

Trampoline で再帰処理の最適化

More than 3 years have passed since last update.

Scala ではコンパイラや JavaVM では最適化されない末尾相互再帰を Trampoline (トランポリン) を使って解決します。また末尾再帰になっていない再帰処理をコールスタックを消費しないで実行するためにも Trampoline を使用します。

C/C++/Java 暦が長いんで末尾最適化されずコールスタック量も想定できない処理は自前でループに書き換えていましたけど、後述する問題を Trampoline で解決できないかなと知識まとめがてらの投稿です。


前提知識: Scala の末尾相互再帰

関数やメソッドの再帰呼び出しはその再帰が処理の一番最後であれば理論的にループ命令に展開できます。これは再帰ごとにその呼び出しのコールスタックをクリアしても問題ないためで、ループの方がコールスタックを消費しないので再帰がどれだけ深くなっても一定量のスタック空間で済む (Stack Overflow が起きない) という利点があります。

ざっくりコードで示せば以下のような変換を Scala コンパイラが自動で行ってくれるという意味です。

def A():X = {

何かの処理
if(終了条件) 結果 else A()
}
// ↓ 以下のようなループ処理にコンパイラが変換
def A():X = {
do 何かの処理 while(終了条件)
結果
}

この末尾最適化は通常コンパイラやインタープリタによって行われます。Scala コンパイラは再帰が末尾最適化可能であればその再帰処理をループ命令に展開します。@tailrec はうっかり末尾最適化になっていないコードを書いてしまったときにコンパイルエラーとしてくれる便利なアノテーションですので、末尾最適化を意図している再帰には @tailrec を必ず付けるべし。

さて、Scala コンパイラも万能というわけではなく、以下のような末尾相互再帰も理論的にはループに変換が可能だが最適化は行いません。

def A():X = if(終了条件) 結果 else B()

def B():X = if(終了条件) 結果 else A()

Scala では Trampoline を実装した scala.util.control.TailCalls が用意されていて、末尾相互再帰のケースではこれを使ってスタック消費の問題を解決します。


末尾相互再帰を Trampoline で解決する

TailCalls API リファレンスのサンプルにあるように数値が偶数かどうかを判定する関数を Trampoline を使わない末尾相互再帰で書いてみよう。

def isEven(value:Long):Boolean = {

def _isEven(x:Long):Boolean = if(x == 0) true else _isOdd(x - 1)
def _isOdd(x:Long):Boolean = if(x == 0) false else _isEven(x - 1)
_isEven(value)
}

上記をある大きな数に対して実行すると想定通り Stack Overflow が発生。末尾相互再帰なので Scala コンパイラが最適化されていないためだ。

scala> isEven(100000L)

java.lang.StackOverflowError
at ._isOdd$1(<console>:12)
at ._isEven$1(<console>:11)
at ._isOdd$1(<console>:12)
...

ではこれを Trampoline を使って書き直してみよう。

import scala.util.control.TailCalls._

def isEven(value:Long):Boolean = {
def _isEven(x:Long):TailRec[Boolean] = if(x == 0) done(true) else tailcall{ _isOdd(x - 1) }
def _isOdd(x:Long):TailRec[Boolean] = if(x == 0) done(false) else tailcall{ _isEven(x - 1) }
_isEven(value).result
}

終了条件が確定したところで done(結果) で終了する。そうでなければ tailcall{再帰呼び出し} で再帰呼び出しを行えばよい。

scala> isEven(100000L)

res21: Boolean = true
scala> isEven(9999999999L)
res22: Boolean = false

実行時間が長くなるのでこれ以上は確認できないが同じ条件でも Stack Overflow は発生しなくなっている。


末尾最適化できない再帰を Trampoline で解決する

再帰呼び出しを使用して階乗計算を実装してみます。この処理は再帰部分が末尾ではない (再帰後に $n \times $ の処理が存在する) ためコンパイラが最適化ができません。

def factorial(n:Int):BigInt = if(n < 2) 1 else n * factorial(n - 1)

実際 @tailrec を付けるとコンパイルエラー error: could not optimize @tailrec annotated method factorial: it contains a recursive call not in tail position が発生します。そして末尾最適化が行われなければ当然どこかで StackOverflowError が発生します。

scala> factorial(10000)

java.lang.StackOverflowError
at scala.math.BigInt$.int2bigInt(BigInt.scala:97)
at .factorial(<console>:19)
...

このように、再帰の結果で処理を行う場合も Trampoline を使って Stack Overflow を回避することができます。

import scala.util.control.TailCalls._

def factorial(n:Int):TailRec[BigInt] = if(n < 2) done(1) else {
tailcall{ factorial(n - 1) }.map{ x => n * x }
}

TailRecresult を参照した時点で再帰処理を実行する遅延評価型の動作をします。このときに map{} を使用して再帰呼び出しの結果に対する処理を記述することができます。

scala> factorial(100000).result

res57: BigInt = 28242294079603478...

TailRecmap, flatMap が実装されているため for yield でも使用することができます。

def factorial(n:Int):TailRec[BigInt] = if(n < 2) done(1) else for{

x <- tailcall{ factorial(n - 1) }
} yield (n * x)

TailCalls の API リファレンスにあるフィボナッチ数の例のように2つ以上の TailRec を扱う場合は flatMap でたたみ込むより for yield を使った方がスッキリ書けると思います。


おまけ: 非同期でのリトライ処理を Trampoline で解決… できる?

さて、もともと Trampoline を使って解決したかった問題がこれ。結論から言うと挫折。

「非同期処理」と「少し間を置いて再帰で再実行」

イメージとしては通信エラーが発生したときに再帰で再実行するだとか、Web クローラーボットのように前回の収集結果を使って次の収集を繰り返すような処理ね。

分散システムでよく使われるこのパターンが再帰とはあまり相性がよろしくない。非同期処理内で再帰を行おうとすると末尾再帰ではなくなるので (スタックではなく) ヒープを消費してしまう。

import scala.concurrent.{Await,Future,Promise}

import scala.concurrent.duration.Duration
import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.{Try,Success,Failure}

// 一定時間後に指定された処理を実行する
val timer = new java.util.Timer()
def at[T](delay:Long)(f: =>T):Future[T] = {
val promise = Promise[T]()
timer.schedule(new java.util.TimerTask(){
override def run() = Try{ f } match {
case Success(x) => promise.success(x)
case Failure(ex) => promise.failure(ex)
}
}, delay)
promise.future
}

// 何かの処理
def retrieve(count:Int):String = {
if(count <= 0) "SUCCESS!" else throw new Throwable()
}

// 何かの処理が成功するまで繰り返す処理
def retryLoop(wait:Long, count:Int):Future[String] = {
at(wait){ retrieve(count) }.recoverWith{ case _ =>
retryLoop(wait, count - 1)
}
}
Await.result(retryLoop(0, 1000000), Duration.Inf)
timer.cancel()

at(){...} のラムダは非同期で実行されブロック外とスタックを共有していないため、retryLoop() の再帰回数が多くなると Stack Overflow ではなく OutOfMemoryError が発生します。

scala> Await.result(retryLoop(0, 1000000), Duration.Inf)

java.lang.OutOfMemoryError: GC overhead limit exceeded
at scala.collection.immutable.List.$colon$colon(List.scala:111)
at scala.concurrent.impl.Promise$DefaultPromise.scala$concurrent$impl$Promise$DefaultPromise$$dispatchOrAddCallback(Promise.scala:282)
...

この retryLoop() だけを Trampoline に書き換えてヒープを消費しないようにできないかと試行錯誤したがどうしたものか?

import scala.util.control.TailCalls._

def retryLoop(wait:Long, count:Int):Future[TailRec[String]] = {
val future = at(wait){ retrieve(count) }
???
}

再帰のたびに Future と TailRec が交互に重なるので単純に flatMap で繋げられないというか。

ヒープを消費する理由は Future を合成するために再帰回数分のインスタンスがメモリ上に乗っているのかなとアタリを付けて、再帰だけど Future を合成しないパターンに変更。

def retryLoop(wait:Long, count:Int):Future[String] = {

val promise = Promise[String]()
def _retryLoop(count:Int):Unit = at(wait){
Try{ retrieve(count) } match {
case Success(x) => promise.success(x)
case Failure(ex) => _retryLoop(count - 1)
}
}
_retryLoop(count)
promise.future
}

これなら OutOfMemoryError も発生せずに成功する。まぁこれで逃げておくか。

scala> Await.result(retryLoop(0, 1000000), Duration.Inf)

res3: String = SUCCESS!