Scala関数型デザイン&プログラミング ―Scalazコントリビューターによる関数型徹底ガイドという本を買って、Scalaを勉強していますが、ちょっと驚いた部分があったので記事にしました。
StackOverflowが発生するコード
本の中で、代数的データ型の説明に際してList
を作るという部分があり、次のようなサンプルコードが載っています。
package fpinscala.datastructures
sealed trait List[+A]
case object Nil extends List[Nothing]
case class Cons[+A](head : A, tail : List[A]) extends List[A]
object List {
def apply[A](as : A*) : List[A] =
if (as.isEmpty) Nil
else Cons(as.head, apply(as.tail : _*))
}
import fpinscala.datastructures.List
object TestList {
def main(args : Array[String]) : Unit = {
val xs = List(1, 2, 3, 4)
println(xs)
}
}
次のような出力になります。
Cons(1,Cons(2,Cons(3,Cons(4,Nil))))
ただし、次のようにたくさんの引数を注入するとStackOverflowが発生します。
import fpinscala.datastructures.List
object TestListWithStackOverflow {
def main(args : Array[String]) : Unit = {
val xs = List( (1 to 10000).toSeq : _* )
println(xs)
}
}
継続渡しスタイル(CPS)で末尾再帰最適化?
List.scala のapply
関数はelse
文で最後に行う処理がCons(as.head, apply(as.tail : _*))
となっており、末尾再帰ではありません。これにより処理系の末尾再帰最適化(Tail call elimination)が受けられなくなり、結果StackOverflowに陥ったというのが単純な仮説です。この関数を 継続渡しスタイル(CPS) を用いて末尾再帰な関数にすれば、末尾再帰最適化によってStackOverflowを回避できるように思われます。先程の関数をCPSで書き直すと次のようになります。
def apply[A](as : A*) : List[A] = {
def loop(k : List[A] => List[A], xs : Seq[A]) : List[A] =
if (xs.isEmpty) k(Nil)
else loop(x => k(Cons(xs.head, x)), xs.tail)
loop(x => x, as)
}
ここでloop
関数はk(Nil)
かloop(x => k(Cons(xs.head, x)), xs.tail)
のどちらかであり、loop
関数の最後に呼び出されるのが自分自身なので末尾再帰関数となります。
こちらのバージョンで TestList.scala を再び実行してみると同じ結果となり、ひとまず機能的にはCPS変換した後と前とで同じであることを確かめられると思います。さて、肝心のStackOverflowは解決されたのかを確かめるために、 TestListWithStackOverflow.scala を実行してみましょう。実はこれを行ってもStackOverflowが表示されて失敗してしまいます。
どうしてこのような結果になったかというと、CPS変換の際に用いる関数k
に問題があります。この関数k
は、呼び出されると前に生成された関数k
を呼び出します。つまり関数k
は関数apply
の引数$n$に対して$n$数珠つなぎになっているということになり、これがScalaのコールスタックを逼迫させます。
トランポリン化による末尾再帰最適化
トランポリン化 と呼ばれる次のような方法を用いて、数珠つなぎになった関数を切断してStackOverflowを回避します。
package fpinscala.tailrec
sealed trait TailRec[A] {
final def run : A = this match {
case Return(v) => v
case Suspend(k) => k().run
}
}
case class Return[A](v : A) extends TailRec[A]
case class Suspend[A](resume : () => TailRec[A]) extends TailRec[A]
def apply[A] (as : A*) : List[A] = {
def loop(k : List[A] => TailRec[List[A]], xs : Seq[A]) : TailRec[List[A]] =
if (xs.isEmpty) k(Nil)
else loop( x => Suspend(() => k(Cons(xs.head, x))), xs.tail)
loop( x => Return(x), as ).run
}
無事に動くものが完成しました。ストリーム(遅延リスト)を作ったことがある人がいれば、あれに似ていると思ったのではないでしょうか。
まとめ
Scala関数型デザイン&プログラミング ―Scalazコントリビューターによる関数型徹底ガイドや後に紹介する参考文献では、これをさらにFreeモナドへ拡張して話を進めていましたが、まだそこまで読んでいないので読んだらまたまとめようかなと思います。