この記事は, Scala Advent Calendar 2024 の 23 日目の記事です.
Scala 3 の インライン (inline) を使ったメタプログラミングで,再帰関数をコンパイル時に計算させようとしていろいろと試行錯誤したので,記事にしてみます.
タイトルが「3 つの方法でコンパイル時計算『しようとしてみた』」なのは,実際に成功したのは 1 つだからです....
実装一覧
フィボナッチ数 fibonacci(45)
の計算を,次のパターンで実装してみました.
実装 | 説明 |
---|---|
実装 0 | インライン化しない普通の実装 |
実装 1 | フィボナッチ関数を inline def で記述 |
実装 2 | フィボナッチ関数の評価を準クォート '{ ... } で記述 |
実装 3 | フィボナッチ関数の評価を scala.quoted.Expr で記述 |
TL;DR: うまくコンパイル時計算ができたのは実装 3 だけでした.
実装 0: インライン化しない普通の実装
まずは,インライン化せずに普通に実装したフィボナッチ関数をコンパイルしてみます.
Fibonacci.scala
:
def fibonacci(n: Int): Int =
if n == 0 then 0
else if n == 1 then 1
else fibonacci(n - 1) + fibonacci(n - 2)
Main.scala
:
@main def main() = {
val result = fibonacci(45)
println(s"result: $result")
}
コンパイル時に -Xprint:inlining
オプションを指定して,インライン化フェーズ後の構文木を観察してみます.
また,time
コマンドを使用してコンパイルの所要時間を計測します.
各ソースのコンパイル結果は次のとおりです.
Fibonacci.scala
のコンパイル結果:
$ time scalac -Xprint:inlining Fibonacci.scala
[[syntax trees at end of inlining]] // Fibonacci.scala
package <empty> {
import scala.quoted.{Expr, Quotes}
final lazy module val Fibonacci$package: Fibonacci$package = new Fibonacci$package()
@SourceFile("Fibonacci.scala") final module class Fibonacci$package() extends Object() { this: Fibonacci$package.type =>
private def writeReplace(): AnyRef = new scala.runtime.ModuleSerializationProxy(classOf[Fibonacci$package.type])
def fibonacci(n: Int): Int = if n.==(0) then 0 else if n.==(1) then 1 else fibonacci(n.-(1)).+(fibonacci(n.-(2)))
}
}
real 0m1.113s
user 0m4.086s
sys 0m0.232s
Main.scala
のコンパイル結果:
$ time scalac -Xprint:inlining Main.scala
[[syntax trees at end of inlining]] // Main.scala
package <empty> {
final lazy module val Main$package: Main$package = new Main$package()
@SourceFile("Main.scala") final module class Main$package() extends Object() { this: Main$package.type =>
private def writeReplace(): AnyRef = new scala.runtime.ModuleSerializationProxy(classOf[Main$package.type])
@main def main: Unit =
{
val result: Int = fibonacci(45)
println(_root_.scala.StringContext.apply(["result: ","" : String]*).s([result : Any]*))
}
}
@SourceFile("Main.scala") final class main() extends Object() {
<static> def main(args: Array[String]): Unit =
try main catch
{
case error @ _:scala.util.CommandLineParser.ParseError => scala.util.CommandLineParser.showError(error)
}
}
}
real 0m1.174s
user 0m4.464s
sys 0m0.255s
プログラムの実行結果は次のとおりです.
実行結果:
$ time scala run -cp .
result: 1134903170
real 0m4.552s
user 0m5.865s
sys 0m0.210s
コンパイルにかかった時間 (Fibonacci.scala
と Main.scala
の合計) とプログラム実行にかかった時間は次のとおりでした.
- コンパイル時間: 2.3 秒
- プログラム実行時間: 4.6 秒
今回の環境では Scala 3.6.2 を使用しています.
計測対象としている時間は,プログラムのコンパイルと実行に掛かった実時間です (1 回測定しただけの参考値).
実装 1: フィボナッチ関数を inline def
で記述
フィボナッチ関数の定義に inline
修飾子を指定してみます.
Fibonacci.scala
:
inline def fibonacci(n: Int): Int =
if n == 0 then 0
else if n == 1 then 1
else fibonacci(n - 1) + fibonacci(n - 2)
Main.scala
の実装は,実装 0 と同じです.
Fibonacci.scala
は次のように普通にコンパイルできました.
Fibonacci.scala
のコンパイル結果:
$ time scalac -Xprint:inlining Fibonacci.scala
[[syntax trees at end of inlining]] // Fibonacci.scala
package <empty> {
import scala.quoted.{Expr, Quotes}
final lazy module val Fibonacci$package: Fibonacci$package = new Fibonacci$package()
@SourceFile("Fibonacci.scala") final module class Fibonacci$package() extends Object() { this: Fibonacci$package.type =>
private def writeReplace(): AnyRef = new scala.runtime.ModuleSerializationProxy(classOf[Fibonacci$package.type])
inline def fibonacci(n: Int): Int = (if n.==(0) then 0 else if n.==(1) then 1 else fibonacci(n.-(1)).+(fibonacci(n.-(2)))):Int
}
}
real 0m1.033s
user 0m3.647s
sys 0m0.241s
ところが,次に Main.scala
をコンパイルしようとしてみたところ,エラーになってしまいました.
Main.scala
のコンパイル結果:
$ time scalac -Xprint:inlining Main.scala
-- Error: Main.scala:2:24 ---------------------------------------------------------------------------------------------------------------------------------
2 | val result = fibonacci(45)
| ^^^^^^^^^^^^^
| Maximal number of successive inlines (32) exceeded,
| Maybe this is caused by a recursive inline method?
| You can use -Xmax-inlines to change the limit.
|-------------------------------------------------------------------------------------------------------------------------------------------------------
|Inline stack trace
|- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|This location contains code that was inlined from Main.scala:2
13 | else fibonacci(n - 1) + fibonacci(n - 2)
| ^^^^^^^^^^^^^^^^
|- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
-- snip --
|- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
|This location contains code that was inlined from Main.scala:2
13 | else fibonacci(n - 1) + fibonacci(n - 2)
| ^^^^^^^^^^^^^^^^
-------------------------------------------------------------------------------------------------------------------------------------------------------
[[syntax trees at end of inlining]] // Main.scala
package <empty> {
final lazy module val Main$package: Main$package = new Main$package()
@SourceFile("Main.scala") final module class Main$package() extends Object() { this: Main$package.type =>
private def writeReplace(): AnyRef = new scala.runtime.ModuleSerializationProxy(classOf[Main$package.type])
@main def main: Unit =
{
val result: Int =
(
(
(
(
(
(
(
(
(
(
(
(
(
(
(
(
(
(
(((((...(...):...Int).+(fibonacci(21)):Int).+(fibonacci(22)):Int).+(fibonacci(23)):Int).+(fibonacci(24)):Int)
.+(fibonacci(25))
:Int).+(fibonacci(26))
:Int).+(fibonacci(27))
:Int).+(fibonacci(28))
:Int).+(fibonacci(29))
:Int).+(fibonacci(30))
:Int).+(fibonacci(31))
:Int).+(fibonacci(32))
:Int).+(fibonacci(33))
:Int).+(fibonacci(34))
:Int).+(fibonacci(35))
:Int).+(fibonacci(36))
:Int).+(fibonacci(37))
:Int).+(fibonacci(38))
:Int).+(fibonacci(39))
:Int).+(fibonacci(40))
:Int).+(fibonacci(41))
:Int).+(fibonacci(42))
:Int).+(fibonacci(43)):Int
println(_root_.scala.StringContext.apply(["result: ","" : String]*).s([result : Any]*))
}
}
@SourceFile("Main.scala") final class main() extends Object() {
<static> def main(args: Array[String]): Unit =
try main catch
{
case error @ _:scala.util.CommandLineParser.ParseError => scala.util.CommandLineParser.showError(error)
}
}
}
1 error found
real 0m1.059s
user 0m4.133s
sys 0m0.254s
オプション -Xmax-inlines 1024
をつけてコンパイルしてみますが,10 分以上応答が帰ってこないため,計測を中断しました....
$ time scalac -Xmax-inlines 1024 -Xprint:inlining Main.scala
コンパイルとプログラム実行にかかった時間は次のとおりです.
- コンパイル時間: 600 秒以上
- プログラム実行時間: 未測定
再帰関数の定義に inline
修飾子を指定すると,インライン化も再帰的に行われて,コンパイルにものすごく時間がかかる場合があるようです.
実装 2: フィボナッチ関数の評価を準クォート '{ ... }
で記述
フィボナッチ関数の計算を 準クォート (quasiquote) '{ ... }
で記述してみます.
Fibonacci.scala
:
import scala.quoted.{Expr, Quotes}
inline def fibonacci(n: Int): Int = ${ fibonacciImpl('n) }
def fibonacciImpl(n: Expr[Int])(using Quotes): Expr[Int] = '{
def fib(x: Int): Int =
if x == 0 then 0
else if x == 1 then 1
else fib(x - 1) + fib(x - 2)
fib($n)
}
Fibonacci.scala
のコンパイル結果:
$ time scalac -Xprint:inlining Fibonacci.scala
[[syntax trees at end of inlining]] // Fibonacci.scala
package <empty> {
import scala.quoted.{Expr, Quotes}
final lazy module val Fibonacci$package: Fibonacci$package = new Fibonacci$package()
@SourceFile("Fibonacci.scala") final module class Fibonacci$package() extends Object() { this: Fibonacci$package.type =>
private def writeReplace(): AnyRef = new scala.runtime.ModuleSerializationProxy(classOf[Fibonacci$package.type])
inline def fibonacci(n: Int): Int =
${
{
def $anonfun(using contextual$1: scala.quoted.Quotes): scala.quoted.Expr[Int] = fibonacciImpl('{n}.apply(contextual$1))(contextual$1)
closure($anonfun)
}
}:Int
def fibonacciImpl(n: scala.quoted.Expr[Int])(using x$2: scala.quoted.Quotes): scala.quoted.Expr[Int] =
'{
{
def fib(x: Int): Int = if x.==(0) then 0 else if x.==(1) then 1 else fib(x.-(1)).+(fib(x.-(2)))
fib(
${
{
def $anonfun(using contextual$2: scala.quoted.Quotes): scala.quoted.Expr[Int] = n
closure($anonfun)
}
}
)
}
}.apply(x$2)
}
}
real 0m1.280s
user 0m4.751s
sys 0m0.266s
Main.scala
のコンパイル結果:
$ time scalac -Xprint:inlining Main.scala
[[syntax trees at end of inlining]] // Main.scala
package <empty> {
final lazy module val Main$package: Main$package = new Main$package()
@SourceFile("Main.scala") final module class Main$package() extends Object() { this: Main$package.type =>
private def writeReplace(): AnyRef = new scala.runtime.ModuleSerializationProxy(classOf[Main$package.type])
@main def main: Unit =
{
val result: Int =
{
def fib(x: Int): Int = (if x.==(0) then 0 else if x.==(1) then 1 else fib(x.-(1)).+(fib(x.-(2))))
fib(45)
}:Int
println(_root_.scala.StringContext.apply(["result: ","" : String]*).s([result : Any]*))
}
}
@SourceFile("Main.scala") final class main() extends Object() {
<static> def main(args: Array[String]): Unit =
try main catch
{
case error @ _:scala.util.CommandLineParser.ParseError => scala.util.CommandLineParser.showError(error)
}
}
}
real 0m1.231s
user 0m4.596s
sys 0m0.271s
実行結果:
$ time scala run -cp .
result: 1134903170
real 0m4.636s
user 0m6.076s
sys 0m0.210s
コンパイルとプログラム実行にかかった時間は次のとおりです.
- コンパイル時間: 2.5 秒
- プログラム実行時間: 4.6 秒
コンパイル時計算は失敗です.コンパイル時のログを見ても,プログラムの実行時間を見ても,コンパイル時計算ができていないことがわかります.
方法 3: フィボナッチ関数の評価を scala.quoted.Expr
で記述
フィボナッチ関数の評価を scala.quoted.Expr で記述してみます.
Fibonacci.scala
:
import scala.quoted.{Expr, Quotes}
inline def fibonacci(n: Int): Int = ${ fibonacciImpl('n) }
def fibonacciImpl(n: Expr[Int])(using Quotes): Expr[Int] = {
def fib(x: Int): Int =
if x == 0 then 0
else if x == 1 then 1
else fib(x - 1) + fib(x - 2)
Expr(fib(n.valueOrAbort))
}
Fibonacci.scala
のコンパイル結果:
$ time scalac -Xprint:inlining Fibonacci.scala
[[syntax trees at end of inlining]] // Fibonacci.scala
package <empty> {
import scala.quoted.{Expr, Quotes}
final lazy module val Fibonacci$package: Fibonacci$package = new Fibonacci$package()
@SourceFile("Fibonacci.scala") final module class Fibonacci$package() extends Object() { this: Fibonacci$package.type =>
private def writeReplace(): AnyRef = new scala.runtime.ModuleSerializationProxy(classOf[Fibonacci$package.type])
inline def fibonacci(n: Int): Int =
${
{
def $anonfun(using contextual$1: scala.quoted.Quotes): scala.quoted.Expr[Int] = fibonacciImpl('{n}.apply(contextual$1))(contextual$1)
closure($anonfun)
}
}:Int
def fibonacciImpl(n: scala.quoted.Expr[Int])(using x$2: scala.quoted.Quotes): scala.quoted.Expr[Int] =
{
def fib(x: Int): Int = if x.==(0) then 0 else if x.==(1) then 1 else fib(x.-(1)).+(fib(x.-(2)))
scala.quoted.Expr.apply[Int](fib(x$2.valueOrAbort[Int](n)(scala.quoted.FromExpr.IntFromExpr[Int])))(scala.quoted.ToExpr.IntToExpr[Int])(x$2)
}
}
}
real 0m1.342s
user 0m5.831s
sys 0m0.258s
Main.scala
のコンパイル結果は次のとおりです.
インライン化のフェーズでコンパイル時に 1134903170
が計算されていることがわかります.
Main.scala
のコンパイル結果:
$ time scalac -Xprint:inlining Main.scala
[[syntax trees at end of inlining]] // Main.scala
package <empty> {
final lazy module val Main$package: Main$package = new Main$package()
@SourceFile("Main.scala") final module class Main$package() extends Object() { this: Main$package.type =>
private def writeReplace(): AnyRef = new scala.runtime.ModuleSerializationProxy(classOf[Main$package.type])
@main def main: Unit =
{
val result: Int = 1134903170:Int
println(_root_.scala.StringContext.apply(["result: ","" : String]*).s([result : Any]*))
}
}
@SourceFile("Main.scala") final class main() extends Object() {
<static> def main(args: Array[String]): Unit =
try main catch
{
case error @ _:scala.util.CommandLineParser.ParseError => scala.util.CommandLineParser.showError(error)
}
}
}
real 0m4.790s
user 0m7.936s
sys 0m0.236s
実行結果:
$ time scala run -cp .
result: 1134903170
real 0m0.931s
user 0m2.254s
sys 0m0.192s
コンパイルとプログラム実行にかかった時間は次のとおりです.
- コンパイル時間: 6.1 秒
- プログラム実行時間: 0.9 秒
プログラム実行時間には Java VM の立ち上げにかかる時間も含まれていますので,main メソッドの実行に掛かっている正味の時間はこれよりも短いはずです.
この実装では,期待通りにコンパイル時にフィボナッチ数を計算させることができました!
まとめ
今回試した実装の中では,実装 1 と 実装 2 はうまく動かず,実装 3 で fibonacci(45)
のコンパイル時計算を (期待される時間内に) 成功させることができました.
説明 | コンパイル時間 (秒) | プログラム実行時間 (秒) | 結果 | |
---|---|---|---|---|
実装 0 | インライン化しない普通の実装 | 2.3 | 4.6 | |
実装 1 | フィボナッチ関数を inline def で記述 |
> 600 | n/a | コンパイルに非常に時間がかかる |
実装 2 | フィボナッチ関数の評価を準クォート '{ ... } で記述 |
2.5 | 4.6 | コンパイル時に計算されない |
実装 3 | フィボナッチ関数の評価を scala.quoted.Expr で記述 |
6.1 | 0.9 | コンパイル時に期待通り計算できた |
実装 1 や実装 2 はもしかすると実装を工夫すれば動くようになるのかもしれませんが,今回の実装のままでは,コンパイルに非常に時間がかかったり,コンパイル時に計算されなかったりして,期待通り動きませんでした.
理論上は正しい実装になっていても,実装 1 で見たように再帰関数などでインライン化が多発する場合には,コンパイル時間が非常に長くなってしまうことがあるという点に注意が必要です.
ところで,Scala 3 でインライン化が強力にサポートされるようになって,これまでよりも美しくメタプログラミングを書けるになったと感じています.
いろいろ落とし穴 (筆者が勝手にはまっているだけ) はありますが,Scala におけるインライン化や計算ステージの仕様をよく理解したうえで,楽しくメタプログラミングをしていきたいですね!