1. torao@github

    Posted

    torao@github
Changes in title
+Trampoline で再帰処理の最適化
Changes in tags
Changes in body
Source | HTML | Preview
@@ -0,0 +1,208 @@
+Scala ではコンパイラや JavaVM では最適化されない末尾相互再帰を **Trampoline** (トランポリン) を使って解決します。また末尾再帰になっていない再帰処理をコールスタックを消費しないで実行するためにも Trampoline を使用します。
+
+## Scala の末尾相互再帰
+
+関数やメソッドの再帰呼び出しは、その再帰が処理の一番最後であれば理論的にループ命令に展開できます。これは再帰ごとにその呼び出しのコールスタックをクリアしても問題ないためで、ループの方がコールスタックを消費しないので再帰がどれだけ深くなっても一定量のスタック空間で済む (Stack Overflow が起きない) という利点があります。
+
+ざっくりコードで示せば以下のような変換を Scala コンパイラが自動で行ってくれる。
+
+```scala
+def A():X = {
+ 何かの処理
+ if(終了条件) 結果 else A()
+}
+// ↓ 以下のようなループ処理にコンパイラが変換
+def A():X = {
+ do 何かの処理 while(終了条件)
+ 結果
+}
+```
+
+この**末尾最適化**は通常コンパイラやインタープリタによって行われます。Scala コンパイラは再帰が末尾最適化可能であればその再帰処理をループ命令に展開します。`@tailrec` はうっかり末尾最適化になっていないコードを書いてしまったときにコンパイルエラーとしてくれる便利なアノテーションですので、末尾最適化を意図している再帰には `@tailrec` を必ず付けるべし。
+
+さて、Scala コンパイラも万能というわけではなく、以下のような**末尾相互再帰**も理論的にはループに変換が可能だが最適化は行いません。
+
+```scala
+def A():X = if(終了条件) 結果 else B()
+def B():X = if(終了条件) 結果 else A()
+```
+
+Scala では Trampoline を実装した `scala.util.control.TailCalls` が用意されていて、末尾相互再帰のケースではこれを使ってスタック消費の問題を解決します。
+
+## 末尾相互再帰を Trampoline で解決する
+
+[`TailCalls` API リファレンス](http://www.scala-lang.org/api/current/scala/util/control/TailCalls$.html)のサンプルにあるように数値が偶数かどうかを判定する関数を Trampoline を使わない末尾相互再帰で書いてみよう。
+
+```scala
+def isEven(value:Long):Boolean = {
+ def _isEven(x:Long):Boolean = if(x == 0) true else _isOdd(x - 1)
+ def _isOdd(x:Long):Boolean = if(x == 0) false else _isEven(x - 1)
+ _isEven(value)
+}
+```
+
+上記をある大きな数に対して実行すると想定通り Stack Overflow が発生。末尾相互再帰なので Scala コンパイラが最適化されていないためだ。
+
+```
+scala> isEven(100000L)
+java.lang.StackOverflowError
+ at ._isOdd$1(<console>:12)
+ at ._isEven$1(<console>:11)
+ at ._isOdd$1(<console>:12)
+ ...
+```
+
+ではこれを Trampoline を使って書き直してみよう。
+
+```scala
+import scala.util.control.TailCalls._
+def isEven(value:Long):Boolean = {
+ def _isEven(x:Long):TailRec[Boolean] = if(x == 0) done(true) else tailcall{ _isOdd(x - 1) }
+ def _isOdd(x:Long):TailRec[Boolean] = if(x == 0) done(false) else tailcall{ _isEven(x - 1) }
+ _isEven(value).result
+}
+```
+
+終了条件が確定したところで `done(結果)` で終了する。そうでなければ `tailcall{再帰呼び出し}` で再帰呼び出しを行えばよい。
+
+```
+scala> isEven(100000L)
+res21: Boolean = true
+scala> isEven(9999999999L)
+res22: Boolean = false
+```
+
+実行時間が長くなるのでこれ以上は確認できないが同じ条件でも Stack Overflow は発生しなくなっている。
+
+## 末尾最適化できない再帰を Trampoline で解決する
+
+再帰呼び出しを使用して階乗計算を実装してみます。この処理は再帰部分が末尾ではない (再帰後に $n \times $ の処理が存在する) ためコンパイラが最適化ができません。
+
+```scala
+def factorial(n:Int):BigInt = if(n < 2) 1 else n * factorial(n - 1)
+```
+
+実際 `@tailrec` を付けるとコンパイルエラー `error: could not optimize @tailrec annotated method factorial: it contains a recursive call not in tail position` が発生します。そして末尾最適化が行われなければ当然どこかで `StackOverflowError` が発生します。
+
+```
+scala> factorial(10000)
+java.lang.StackOverflowError
+ at scala.math.BigInt$.int2bigInt(BigInt.scala:97)
+ at .factorial(<console>:19)
+ ...
+```
+
+このように、再帰の結果で処理を行う場合も Trampoline を使って Stack Overflow を回避することができます。
+
+```scala
+import scala.util.control.TailCalls._
+def factorial(n:Int):TailRec[BigInt] = if(n < 2) done(1) else {
+ tailcall{ factorial(n - 1) }.map{ x => n * x }
+}
+```
+
+`TailRec` は `result` を参照した時点で再帰処理を実行する遅延評価型の動作をします。このときに `map{}` を使用して再帰呼び出しの結果に対する処理を記述することができます。
+
+```
+scala> factorial(100000).result
+res57: BigInt = 28242294079603478...
+```
+
+`TailRec` は `map`, `flatMap` が実装されているため `for yield` でも使用することができます。
+
+```scala
+def factorial(n:Int):TailRec[BigInt] = if(n < 2) done(1) else for{
+ x <- tailcall{ factorial(n - 1) }
+} yield (n * x)
+```
+
+`TailCalls` の API リファレンスにあるフィボナッチ数の例のように2つ以上の `TailRec` を扱う場合は `flatMap` でたたみ込むより `for yield` を使った方がスッキリ書けると思います。
+
+## おまけ: 非同期でのリトライ処理を Trampoline で解決… できる?
+
+さて、もともと Trampoline を使って解決したかった問題がこれ。結論から言うと挫折。
+
+「非同期処理」と「少し間を置いて再帰で再実行」
+
+イメージとしては通信エラーが発生したときに再帰で再実行するだとか、Web クローラーボットのように前回の収集結果を使って次の収集を繰り返すような処理ね。
+
+分散システムでよく使われるこのパターンが再帰とはあまり相性がよろしくない。非同期処理内で再帰を行おうとすると末尾再帰ではなくなるので (スタックではなく) ヒープを消費してしまう。
+
+```scala
+import scala.concurrent.{Await,Future,Promise}
+import scala.concurrent.duration.Duration
+import scala.concurrent.ExecutionContext.Implicits.global
+import scala.util.{Try,Success,Failure}
+
+// 一定時間後に指定された処理を実行する
+val timer = new java.util.Timer()
+def at[T](delay:Long)(f: =>T):Future[T] = {
+ val promise = Promise[T]()
+ timer.schedule(new java.util.TimerTask(){
+ override def run() = Try{ f } match {
+ case Success(x) => promise.success(x)
+ case Failure(ex) => promise.failure(ex)
+ }
+ }, delay)
+ promise.future
+}
+
+// 何かの処理
+def retrieve(count:Int):String = {
+ if(count <= 0) "SUCCESS!" else throw new Throwable()
+}
+
+// 何かの処理が成功するまで繰り返す処理
+def retryLoop(wait:Long, count:Int):Future[String] = {
+ at(wait){ retrieve(count) }.recoverWith{ case _ =>
+ retryLoop(wait, count - 1)
+ }
+}
+Await.result(retryLoop(0, 1000000), Duration.Inf)
+timer.cancel()
+```
+
+`at(){...}` のラムダは非同期で実行されブロック外とスタックを共有していないため、`retryLoop()` の再帰回数が多くなると Stack Overflow ではなく `OutOfMemoryError` が発生します。
+
+```
+scala> Await.result(retryLoop(0, 1000000), Duration.Inf)
+java.lang.OutOfMemoryError: GC overhead limit exceeded
+ at scala.collection.immutable.List.$colon$colon(List.scala:111)
+ at scala.concurrent.impl.Promise$DefaultPromise.scala$concurrent$impl$Promise$DefaultPromise$$dispatchOrAddCallback(Promise.scala:282)
+ ...
+```
+
+この `retryLoop()` だけを Trampoline に書き換えてヒープを消費しないようにできないかと試行錯誤したがどうしたものか?
+
+```scala
+import scala.util.control.TailCalls._
+def retryLoop(wait:Long, count:Int):Future[TailRec[String]] = {
+ val future = at(wait){ retrieve(count) }
+ ???
+}
+```
+
+ 再帰のたびに Future と TailRec が交互に重なるので単純に flatMap で繋げられないというか。
+
+ヒープを消費する理由は Future を合成するために再帰回数分のインスタンスがメモリ上に乗っているのかなとアタリを付けて、再帰だけど Future を合成しないパターンに変更。
+
+```scala
+def retryLoop(wait:Long, count:Int):Future[String] = {
+ val promise = Promise[String]()
+ def _retryLoop(count:Int):Unit = at(wait){
+ Try{ retrieve(count) } match {
+ case Success(x) => promise.success(x)
+ case Failure(ex) => _retryLoop(count - 1)
+ }
+ }
+ _retryLoop(count)
+ promise.future
+}
+```
+
+これなら OutOfMemoryError も発生せずに成功する。まぁこれで逃げておくか。
+
+```
+scala> Await.result(retryLoop(0, 1000000), Duration.Inf)
+res3: String = SUCCESS!
+```