LoginSignup
11
10

More than 5 years have passed since last update.

トランポリン化とアキュムレータの性能比較

Last updated at Posted at 2015-04-14

Scalaのトランポリン化については以前記事にしましたが、CPS変換については次のような指摘がありました。

アキュムレータを使って解決できる場合、そちらの方が早そうという仮説ですが、これがScalaではどうなるのか試すことにしました。

実験に使うコード

引数$n$を取り、1を$n$回足し算する次のようなプログラムを用います。

TailRec.scala
package fpinscala.tailrec

sealed trait TailRec[A] {
  final def run : A = this match {
    case Return(v)  => v
    case Suspend(k) => k().run
  }
}

case class Return[A](v : A) extends TailRec[A]
case class Suspend[A](resume : () => TailRec[A]) extends TailRec[A]
Counting.scala
import fpinscala.tailrec._

object Counting {
  def normal (n : Int) : Int =
    if (n == 0) 0
    else 1 + normal(n - 1)

  def cps (n : Int) : Int = {
    def loop (i : Int, k : Int => TailRec[Int]) : TailRec[Int] =
      if (i == 0) k(0)
      else loop(i - 1, x => Suspend(() => k(1 + x)))

    loop(n, t => Return(t)).run
  }

  def accum (n : Int) : Int = {
    def loop (i : Int, a : Int) : Int =
      if (i == 0) a
      else loop(i - 1, a + 1)

    loop(n, 0)
  }

  def main(args : Array[String]) : Unit = {
    val n = 100000
    val s1 = System.currentTimeMillis()
    accum(n)
    println(System.currentTimeMillis() - s1)

    val s2 = System.currentTimeMillis()
    cps(n)
    println(System.currentTimeMillis() - s2)

    val s3 = System.currentTimeMillis()
    normal(n)
    println(System.currentTimeMillis() - s3)
  }
}

結果

function \ $n$ 100 1000 10000 100000 1000000 10000000
normal 0 0 1 × × ×
accum 240 234 243 268 239 228
cps (trampoline) 8 7 12 29 133 2264

まとめ

どうやら次のことが言えそうです。

  • $n$が少ない時はnormalが最速
  • $n$が多くはないが、normalがStackOverflowするような場合はcpsが有力
  • どんな$n$に対してもaccumは一定の速度

$n$が増えてもaccumの速度が一定なのはどうしてなのか分からないですが、ひとまずこのような結果となりました。

追記

@xuwei_k さんの指摘に基づいて、複数回実行してその平均時間を取る方法の結果も掲載します。

コード

def main(args : Array[String]) : Unit = {
  def getTime[A] (f : A => Any, i : A, n : Int) : Double = {
    val s = System.currentTimeMillis()
    for (_ <- 1 to n)
      f(i)

    (System.currentTimeMillis() - s).toDouble / n.toDouble
  }

  for (i <- 2 to 7)
    print(s"${getTime(accum, pow(10, i).toInt, 10)},")

  println()

  for (i <- 2 to 7)
    print(s"${getTime(cps, pow(10, i).toInt, 10)},")
}

結果

function \ $n$ 100 1000 10000 100000 1000000 10000000
accum 0.1 0.1 0.2 0.3 0.0 0.0
cps (trampoline) 0.8 0.3 1.6 4.6 71.7 932.5

@xuwei_k さんのご指摘どおり、平均を取るとどの回数でもaccumが早いようです。

11
10
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
10