2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

ScalaAdvent Calendar 2024

Day 23

Scala プログラムの再帰関数を 3 つの方法でコンパイル時計算しようとしてみた

Last updated at Posted at 2024-12-23

この記事は, Scala Advent Calendar 2024 の 23 日目の記事です.

Scala 3 の インライン (inline) を使ったメタプログラミングで,再帰関数をコンパイル時に計算させようとしていろいろと試行錯誤したので,記事にしてみます.
タイトルが「3 つの方法でコンパイル時計算『しようとしてみた』」なのは,実際に成功したのは 1 つだからです....

実装一覧

フィボナッチ数 fibonacci(45) の計算を,次のパターンで実装してみました.

実装 説明
実装 0 インライン化しない普通の実装
実装 1 フィボナッチ関数を inline def で記述
実装 2 フィボナッチ関数の評価を準クォート '{ ... } で記述
実装 3 フィボナッチ関数の評価を scala.quoted.Expr で記述

TL;DR: うまくコンパイル時計算ができたのは実装 3 だけでした.

実装 0: インライン化しない普通の実装

まずは,インライン化せずに普通に実装したフィボナッチ関数をコンパイルしてみます.

Fibonacci.scala:

def fibonacci(n: Int): Int =
  if n == 0 then 0
  else if n == 1 then 1
  else fibonacci(n - 1) + fibonacci(n - 2)

Main.scala:

@main def main() = {
  val result = fibonacci(45)
  println(s"result: $result")
}

コンパイル時に -Xprint:inlining オプションを指定して,インライン化フェーズ後の構文木を観察してみます.
また,time コマンドを使用してコンパイルの所要時間を計測します.

各ソースのコンパイル結果は次のとおりです.

Fibonacci.scala のコンパイル結果:

$ time scalac -Xprint:inlining Fibonacci.scala
[[syntax trees at end of                  inlining]] // Fibonacci.scala
package <empty> {
  import scala.quoted.{Expr, Quotes}
  final lazy module val Fibonacci$package: Fibonacci$package = new Fibonacci$package()
  @SourceFile("Fibonacci.scala") final module class Fibonacci$package() extends Object() { this: Fibonacci$package.type =>
    private def writeReplace(): AnyRef = new scala.runtime.ModuleSerializationProxy(classOf[Fibonacci$package.type])
    def fibonacci(n: Int): Int = if n.==(0) then 0 else if n.==(1) then 1 else fibonacci(n.-(1)).+(fibonacci(n.-(2)))
  }
}


real	0m1.113s
user	0m4.086s
sys	0m0.232s

Main.scala のコンパイル結果:

$ time scalac -Xprint:inlining Main.scala
[[syntax trees at end of                  inlining]] // Main.scala
package <empty> {
  final lazy module val Main$package: Main$package = new Main$package()
  @SourceFile("Main.scala") final module class Main$package() extends Object() { this: Main$package.type =>
    private def writeReplace(): AnyRef = new scala.runtime.ModuleSerializationProxy(classOf[Main$package.type])
    @main def main: Unit =
      {
        val result: Int = fibonacci(45)
        println(_root_.scala.StringContext.apply(["result: ","" : String]*).s([result : Any]*))
      }
  }
  @SourceFile("Main.scala") final class main() extends Object() {
    <static> def main(args: Array[String]): Unit =
      try main catch 
        {
          case error @ _:scala.util.CommandLineParser.ParseError => scala.util.CommandLineParser.showError(error)
        }
  }
}


real	0m1.174s
user	0m4.464s
sys	0m0.255s

プログラムの実行結果は次のとおりです.

実行結果:

$ time scala run -cp .
result: 1134903170

real	0m4.552s
user	0m5.865s
sys	0m0.210s

コンパイルにかかった時間 (Fibonacci.scalaMain.scala の合計) とプログラム実行にかかった時間は次のとおりでした.

  • コンパイル時間: 2.3 秒
  • プログラム実行時間: 4.6 秒

今回の環境では Scala 3.6.2 を使用しています.
計測対象としている時間は,プログラムのコンパイルと実行に掛かった実時間です (1 回測定しただけの参考値).

実装 1: フィボナッチ関数を inline def で記述

フィボナッチ関数の定義に inline 修飾子を指定してみます.

Fibonacci.scala:

inline def fibonacci(n: Int): Int =
  if n == 0 then 0
  else if n == 1 then 1
  else fibonacci(n - 1) + fibonacci(n - 2)

Main.scala の実装は,実装 0 と同じです.

Fibonacci.scala は次のように普通にコンパイルできました.

Fibonacci.scala のコンパイル結果:

$ time scalac -Xprint:inlining Fibonacci.scala
[[syntax trees at end of                  inlining]] // Fibonacci.scala
package <empty> {
  import scala.quoted.{Expr, Quotes}
  final lazy module val Fibonacci$package: Fibonacci$package = new Fibonacci$package()
  @SourceFile("Fibonacci.scala") final module class Fibonacci$package() extends Object() { this: Fibonacci$package.type =>
    private def writeReplace(): AnyRef = new scala.runtime.ModuleSerializationProxy(classOf[Fibonacci$package.type])
    inline def fibonacci(n: Int): Int = (if n.==(0) then 0 else if n.==(1) then 1 else fibonacci(n.-(1)).+(fibonacci(n.-(2)))):Int
  }
}


real	0m1.033s
user	0m3.647s
sys	0m0.241s

ところが,次に Main.scala をコンパイルしようとしてみたところ,エラーになってしまいました.

Main.scala のコンパイル結果:

$ time scalac -Xprint:inlining Main.scala
-- Error: Main.scala:2:24 ---------------------------------------------------------------------------------------------------------------------------------
 2 |  val result = fibonacci(45)
   |               ^^^^^^^^^^^^^
   |               Maximal number of successive inlines (32) exceeded,
   |               Maybe this is caused by a recursive inline method?
   |               You can use -Xmax-inlines to change the limit.
   |-------------------------------------------------------------------------------------------------------------------------------------------------------
   |Inline stack trace
   |- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
   |This location contains code that was inlined from Main.scala:2
13 |  else fibonacci(n - 1) + fibonacci(n - 2)
   |       ^^^^^^^^^^^^^^^^
   |- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

   -- snip --

   |- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
   |This location contains code that was inlined from Main.scala:2
13 |  else fibonacci(n - 1) + fibonacci(n - 2)
   |       ^^^^^^^^^^^^^^^^
    -------------------------------------------------------------------------------------------------------------------------------------------------------
[[syntax trees at end of                  inlining]] // Main.scala
package <empty> {
  final lazy module val Main$package: Main$package = new Main$package()
  @SourceFile("Main.scala") final module class Main$package() extends Object() { this: Main$package.type =>
    private def writeReplace(): AnyRef = new scala.runtime.ModuleSerializationProxy(classOf[Main$package.type])
    @main def main: Unit =
      {
        val result: Int =
          (
            (
              (
                (
                  (
                    (
                      (
                        (
                          (
                            (
                              (
                                (
                                  (
                                    (
                                      (
                                        (
                                          (
                                            (
                                              (((((...(...):...Int).+(fibonacci(21)):Int).+(fibonacci(22)):Int).+(fibonacci(23)):Int).+(fibonacci(24)):Int)
                                                .+(fibonacci(25))
                                            :Int).+(fibonacci(26))
                                          :Int).+(fibonacci(27))
                                        :Int).+(fibonacci(28))
                                      :Int).+(fibonacci(29))
                                    :Int).+(fibonacci(30))
                                  :Int).+(fibonacci(31))
                                :Int).+(fibonacci(32))
                              :Int).+(fibonacci(33))
                            :Int).+(fibonacci(34))
                          :Int).+(fibonacci(35))
                        :Int).+(fibonacci(36))
                      :Int).+(fibonacci(37))
                    :Int).+(fibonacci(38))
                  :Int).+(fibonacci(39))
                :Int).+(fibonacci(40))
              :Int).+(fibonacci(41))
            :Int).+(fibonacci(42))
          :Int).+(fibonacci(43)):Int
        println(_root_.scala.StringContext.apply(["result: ","" : String]*).s([result : Any]*))
      }
  }
  @SourceFile("Main.scala") final class main() extends Object() {
    <static> def main(args: Array[String]): Unit =
      try main catch 
        {
          case error @ _:scala.util.CommandLineParser.ParseError => scala.util.CommandLineParser.showError(error)
        }
  }
}

1 error found

real	0m1.059s
user	0m4.133s
sys	0m0.254s

オプション -Xmax-inlines 1024 をつけてコンパイルしてみますが,10 分以上応答が帰ってこないため,計測を中断しました....

$ time scalac -Xmax-inlines 1024 -Xprint:inlining Main.scala

コンパイルとプログラム実行にかかった時間は次のとおりです.

  • コンパイル時間: 600 秒以上
  • プログラム実行時間: 未測定

再帰関数の定義に inline 修飾子を指定すると,インライン化も再帰的に行われて,コンパイルにものすごく時間がかかる場合があるようです.

実装 2: フィボナッチ関数の評価を準クォート '{ ... } で記述

フィボナッチ関数の計算を 準クォート (quasiquote) '{ ... } で記述してみます.

Fibonacci.scala:

import scala.quoted.{Expr, Quotes}

inline def fibonacci(n: Int): Int = ${ fibonacciImpl('n) }

def fibonacciImpl(n: Expr[Int])(using Quotes): Expr[Int] = '{
  def fib(x: Int): Int =
    if x == 0 then 0
    else if x == 1 then 1
    else fib(x - 1) + fib(x - 2)
  fib($n)
}

Fibonacci.scala のコンパイル結果:

$ time scalac -Xprint:inlining Fibonacci.scala
[[syntax trees at end of                  inlining]] // Fibonacci.scala
package <empty> {
  import scala.quoted.{Expr, Quotes}
  final lazy module val Fibonacci$package: Fibonacci$package = new Fibonacci$package()
  @SourceFile("Fibonacci.scala") final module class Fibonacci$package() extends Object() { this: Fibonacci$package.type =>
    private def writeReplace(): AnyRef = new scala.runtime.ModuleSerializationProxy(classOf[Fibonacci$package.type])
    inline def fibonacci(n: Int): Int =
      ${
        {
          def $anonfun(using contextual$1: scala.quoted.Quotes): scala.quoted.Expr[Int] = fibonacciImpl('{n}.apply(contextual$1))(contextual$1)
          closure($anonfun)
        }
      }:Int
    def fibonacciImpl(n: scala.quoted.Expr[Int])(using x$2: scala.quoted.Quotes): scala.quoted.Expr[Int] =
      '{
        {
          def fib(x: Int): Int = if x.==(0) then 0 else if x.==(1) then 1 else fib(x.-(1)).+(fib(x.-(2)))
          fib(
            ${
              {
                def $anonfun(using contextual$2: scala.quoted.Quotes): scala.quoted.Expr[Int] = n
                closure($anonfun)
              }
            }
          )
        }
      }.apply(x$2)
  }
}


real	0m1.280s
user	0m4.751s
sys	0m0.266s

Main.scala のコンパイル結果:

$ time scalac -Xprint:inlining Main.scala
[[syntax trees at end of                  inlining]] // Main.scala
package <empty> {
  final lazy module val Main$package: Main$package = new Main$package()
  @SourceFile("Main.scala") final module class Main$package() extends Object() { this: Main$package.type =>
    private def writeReplace(): AnyRef = new scala.runtime.ModuleSerializationProxy(classOf[Main$package.type])
    @main def main: Unit =
      {
        val result: Int =
          {
            def fib(x: Int): Int = (if x.==(0) then 0 else if x.==(1) then 1 else fib(x.-(1)).+(fib(x.-(2))))
            fib(45)
          }:Int
        println(_root_.scala.StringContext.apply(["result: ","" : String]*).s([result : Any]*))
      }
  }
  @SourceFile("Main.scala") final class main() extends Object() {
    <static> def main(args: Array[String]): Unit =
      try main catch 
        {
          case error @ _:scala.util.CommandLineParser.ParseError => scala.util.CommandLineParser.showError(error)
        }
  }
}


real	0m1.231s
user	0m4.596s
sys	0m0.271s

実行結果:

$ time scala run -cp .
result: 1134903170

real	0m4.636s
user	0m6.076s
sys	0m0.210s

コンパイルとプログラム実行にかかった時間は次のとおりです.

  • コンパイル時間: 2.5 秒
  • プログラム実行時間: 4.6 秒

コンパイル時計算は失敗です.コンパイル時のログを見ても,プログラムの実行時間を見ても,コンパイル時計算ができていないことがわかります.

方法 3: フィボナッチ関数の評価を scala.quoted.Expr で記述

フィボナッチ関数の評価を scala.quoted.Expr で記述してみます.

Fibonacci.scala:

import scala.quoted.{Expr, Quotes}

inline def fibonacci(n: Int): Int = ${ fibonacciImpl('n) }

def fibonacciImpl(n: Expr[Int])(using Quotes): Expr[Int] = {
  def fib(x: Int): Int =
    if x == 0 then 0
    else if x == 1 then 1
    else fib(x - 1) + fib(x - 2)
  Expr(fib(n.valueOrAbort))
}

Fibonacci.scala のコンパイル結果:

$ time scalac -Xprint:inlining Fibonacci.scala
[[syntax trees at end of                  inlining]] // Fibonacci.scala
package <empty> {
  import scala.quoted.{Expr, Quotes}
  final lazy module val Fibonacci$package: Fibonacci$package = new Fibonacci$package()
  @SourceFile("Fibonacci.scala") final module class Fibonacci$package() extends Object() { this: Fibonacci$package.type =>
    private def writeReplace(): AnyRef = new scala.runtime.ModuleSerializationProxy(classOf[Fibonacci$package.type])
    inline def fibonacci(n: Int): Int =
      ${
        {
          def $anonfun(using contextual$1: scala.quoted.Quotes): scala.quoted.Expr[Int] = fibonacciImpl('{n}.apply(contextual$1))(contextual$1)
          closure($anonfun)
        }
      }:Int
    def fibonacciImpl(n: scala.quoted.Expr[Int])(using x$2: scala.quoted.Quotes): scala.quoted.Expr[Int] =
      {
        def fib(x: Int): Int = if x.==(0) then 0 else if x.==(1) then 1 else fib(x.-(1)).+(fib(x.-(2)))
        scala.quoted.Expr.apply[Int](fib(x$2.valueOrAbort[Int](n)(scala.quoted.FromExpr.IntFromExpr[Int])))(scala.quoted.ToExpr.IntToExpr[Int])(x$2)
      }
  }
}


real	0m1.342s
user	0m5.831s
sys	0m0.258s

Main.scala のコンパイル結果は次のとおりです.
インライン化のフェーズでコンパイル時に 1134903170 が計算されていることがわかります.

Main.scala のコンパイル結果:

$ time scalac -Xprint:inlining Main.scala
[[syntax trees at end of                  inlining]] // Main.scala
package <empty> {
  final lazy module val Main$package: Main$package = new Main$package()
  @SourceFile("Main.scala") final module class Main$package() extends Object() { this: Main$package.type =>
    private def writeReplace(): AnyRef = new scala.runtime.ModuleSerializationProxy(classOf[Main$package.type])
    @main def main: Unit =
      {
        val result: Int = 1134903170:Int
        println(_root_.scala.StringContext.apply(["result: ","" : String]*).s([result : Any]*))
      }
  }
  @SourceFile("Main.scala") final class main() extends Object() {
    <static> def main(args: Array[String]): Unit =
      try main catch 
        {
          case error @ _:scala.util.CommandLineParser.ParseError => scala.util.CommandLineParser.showError(error)
        }
  }
}


real	0m4.790s
user	0m7.936s
sys	0m0.236s

実行結果:

$ time scala run -cp .
result: 1134903170

real	0m0.931s
user	0m2.254s
sys	0m0.192s

コンパイルとプログラム実行にかかった時間は次のとおりです.

  • コンパイル時間: 6.1 秒
  • プログラム実行時間: 0.9 秒

プログラム実行時間には Java VM の立ち上げにかかる時間も含まれていますので,main メソッドの実行に掛かっている正味の時間はこれよりも短いはずです.

この実装では,期待通りにコンパイル時にフィボナッチ数を計算させることができました!

まとめ

今回試した実装の中では,実装 1 と 実装 2 はうまく動かず,実装 3 で fibonacci(45) のコンパイル時計算を (期待される時間内に) 成功させることができました.

説明 コンパイル時間 (秒) プログラム実行時間 (秒) 結果
実装 0 インライン化しない普通の実装 2.3 4.6
実装 1 フィボナッチ関数を inline def で記述 > 600 n/a コンパイルに非常に時間がかかる
実装 2 フィボナッチ関数の評価を準クォート '{ ... } で記述 2.5 4.6 コンパイル時に計算されない
実装 3 フィボナッチ関数の評価を scala.quoted.Expr で記述 6.1 0.9 コンパイル時に期待通り計算できた

実装 1 や実装 2 はもしかすると実装を工夫すれば動くようになるのかもしれませんが,今回の実装のままでは,コンパイルに非常に時間がかかったり,コンパイル時に計算されなかったりして,期待通り動きませんでした.

理論上は正しい実装になっていても,実装 1 で見たように再帰関数などでインライン化が多発する場合には,コンパイル時間が非常に長くなってしまうことがあるという点に注意が必要です.

ところで,Scala 3 でインライン化が強力にサポートされるようになって,これまでよりも美しくメタプログラミングを書けるになったと感じています.
いろいろ落とし穴 (筆者が勝手にはまっているだけ) はありますが,Scala におけるインライン化や計算ステージの仕様をよく理解したうえで,楽しくメタプログラミングをしていきたいですね!

参考

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?