概要
以下の記事を参考にScala3のマクロアノテーションを使ってメソッドの戻り値をLRUキャッシュ (Least Recently Used. 直近に参照した時刻が一番古いものを捨てるキャッシュ) で保持するコードを書きました。
-
Macro annotations in Scala 3
- Scala3のマクロアノテーションでメソッドの戻り値をキャッシュする実装の紹介している
-
Scala Interview Series: Implement LRU cache in scala
- ScalaでのLRUキャッシュの実装を紹介している
注意
Scala3のマクロアノテーションはMacroAnnotation
を継承して定義できます。しかし、MacroAnnotation
はバージョン3.3.0-RC2
の実験的な機能であり、アノテーションクラスとそれを利用するクラスに@experimental
アノテーションを付ける必要があります。
LRUキャッシュの実装
import scala.collection.mutable.LinkedHashMap
class LRUCache[K, V](size: Int):
private val cache = LinkedHashMap.empty[K, V]
private val queue = new DoublyLinkedList[K]
def get(key: K): Option[V] =
cache.get(key) match
case Some(value) =>
queue.moveToFront(key)
Some(value)
case None => None
def put(key: K, value: V): Unit =
if (cache.contains(key)) queue.moveToFront(key)
else
if (cache.size >= size)
val last = queue.removeLast()
cache.remove(last)
queue.addToFront(key)
cache.put(key, value)
class DoublyLinkedList[K]:
case class Node(key: K, var prev: Node, var next: Node)
private val head = Node(null.asInstanceOf[K], null, null)
private val tail = Node(null.asInstanceOf[K], head, null)
head.next = tail
private val map = scala.collection.mutable.Map.empty[K, Node]
def addToFront(key: K): Unit =
val node = Node(key, head, head.next)
head.next.prev = node
head.next = node
map.put(key, node)
def remove(node: Node): Unit =
node.prev.next = node.next
node.next.prev = node.prev
map.remove(node.key)
def removeLast(): K =
val last = tail.prev
remove(last)
last.key
def moveToFront(key: K): Unit =
map.get(key) match
case Some(node) =>
remove(node)
addToFront(key)
case None => ()
ほぼ、Scala Interview Series: Implement LRU cache in scalaのコードのコピペです。
サイズが指定可能でシンプルなLRUキャッシュです。
マクロアノテーションの実装
import scala.annotation.{MacroAnnotation, experimental}
import scala.quoted.*
@experimental
class lruCached(size: Int) extends MacroAnnotation:
// [1]
override def transform(using quotes: Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] =
import quotes.reflect.*
// [2]
tree match
case DefDef(name, params, returnType, Some(rhs)) =>
// [3]
val flattenedParams = params.map(_.params).flatten
val paramTermRefs = flattenedParams.map(_.asInstanceOf[ValDef].symbol.termRef)
val paramTuple = Expr.ofTupleFromSeq(paramTermRefs.map(Ident(_).asExpr))
// [4]
(paramTuple, rhs.asExpr) match
case ('{ $p: paramTupleType }, '{ $r: rhsType }) =>
// [5]
val cacheName = Symbol.freshName(name + "Cache")
val cacheType = TypeRepr.of[LRUCache[paramTupleType, rhsType]]
val cacheRhs = '{ LRUCache[paramTupleType, rhsType](${ Expr(size) }) }.asTerm
val cacheSymbol = Symbol.newVal(tree.symbol.owner, cacheName, cacheType, Flags.Private, Symbol.noSymbol)
val cache = ValDef(cacheSymbol, Some(cacheRhs))
val cacheRef = Ref(cacheSymbol).asExprOf[LRUCache[paramTupleType, rhsType]]
// [6]
def buildNewRhs(using q: Quotes) =
import q.reflect.*
'{
val key = ${ paramTuple.asExprOf[paramTupleType] }
$cacheRef.get(key) match
case Some(value) =>
value
case None =>
val result = ${ rhs.asExprOf[rhsType] }
$cacheRef.put(key, result)
result
}
val newRhs = buildNewRhs(using tree.symbol.asQuotes).asTerm
// [7]
val expandedMethod = DefDef.copy(tree)(name, params, returnType, Some(newRhs))
List(cache, expandedMethod)
case _ =>
report.error("Annottee must be a method")
List(tree)
[1] transformメソッド
override def transform(using quotes: Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] =
MacroAnnotation
が提供するtransform
メソッドはtree
の定義を変換し、新しい定義を追加できます。
tree
の型Definition
はソースコード内の定義のツリー構造を表現しており、ClassDef
、TypeDef
、DefDef
、ValDef
のいずれかです。
[2] Definitionのパターンマッチ
tree match
case DefDef(name, params, returnType, Some(rhs)) =>
...
case _ =>
report.error("Annottee must be a method")
List(tree)
tree
がDefDef
型 (ソースコード内のメソッド定義を表すツリー構造)であるかチェックしています。
メソッド以外にアノテーションが付けられている場合、コンパイルエラーとなります。
case DefDef(name, params, returnType, Some(rhs))
はunapply
メソッドを呼んでおり、unapply
のシグネチャは以下のように定義されています。
def unapply(ddef: DefDef): (String, List[ParamClause], TypeTree, Option[Term])
[3] メソッドのパラメータの処理
val flattenedParams = params.map(_.params).flatten
val paramTermRefs = flattenedParams.map(_.asInstanceOf[ValDef].symbol.termRef)
val paramTuple = Expr.ofTupleFromSeq(paramTermRefs.map(Ident(_).asExpr))
メソッドのパラメータをタプルの式の変換しています。
[4] 式のパターンマッチ
(paramTuple, rhs.asExpr) match
case ('{ $p: paramTupleType }, '{ $r: rhsType }) =>
パラメータのタプル式のメソッド内の式をパターンマッチしています。$p
はパラメータのタプル式の値、$r
はメソッド内の式の値を取り出します。
quoted code block ('{ ... }
)については、ドキュメントをご参照ください。
[5] LRUキャッシュを定義
val cacheName = Symbol.freshName(name + "Cache")
val cacheType = TypeRepr.of[LRUCache[paramTupleType, rhsType]]
val cacheRhs = '{ LRUCache[paramTupleType, rhsType](${ Expr(size) }) }.asTerm
val cacheSymbol = Symbol.newVal(tree.symbol.owner, cacheName, cacheType, Flags.Private, Symbol.noSymbol)
val cache = ValDef(cacheSymbol, Some(cacheRhs))
val cacheRef = Ref(cacheSymbol).asExprOf[LRUCache[paramTupleType, rhsType]]
Symbol
は等しい文字列の一意のオブジェクトを取得するためのクラスです。
Ref
は定義への参照を表すツリー構造です。
[6] 変換先のメソッドを定義
def buildNewRhs(using q: Quotes) =
import q.reflect.*
'{
val key = ${ paramTuple.asExprOf[paramTupleType] }
$cacheRhs.get(key) match
case Some(value) =>
value
case None =>
val result = ${ rhs.asExprOf[rhsType] }
$cacheRhs.put(key, result)
result
}
val newRhs = buildNewRhs(using tree.symbol.asQuotes).asTerm
タプル化したパラメータをkeyとしてキャッシュに問い合わせ、戻り値があればそれを返却、なければ元のメソッドを実行し戻り値をキャッシュに保存します。
[7] メソッドを置き換えて返却する
val expandedMethod = DefDef.copy(tree)(name, params, returnType, Some(newRhs))
List(cache, expandedMethod)
expandedMethod
を入力メソッドのコピーとして作成し、元のメソッドを置き換えます。
そして、LRUキャッシュの定義と合わせて返却します。
参考: https://www.codecentric.de/wissens-hub/blog/macro-annotations-in-scala-3
動作確認
import scala.annotation.experimental
@experimental
object Main extends App:
class Calculator:
@lruCached(10)
def add(x: Int, y: Int): Int =
Thread.sleep(3000)
x + y
val calculator = Calculator()
val time1 = System.currentTimeMillis
println(calculator.add(1, 2))
val time2 = System.currentTimeMillis
println(calculator.add(2, 3))
val time3 = System.currentTimeMillis
println(calculator.add(1, 2))
val time4 = System.currentTimeMillis
println("1度目のaddの処理時間: " + (time2 - time1) + " ミリ秒")
println("2度目のaddの処理時間: " + (time3 - time2) + " ミリ秒")
println("3度目のaddの処理時間: " + (time4 - time3) + " ミリ秒")
// 出力結果
// 3
// 5
// 3
// 1度目のaddの処理時間: 3067 ミリ秒
// 2度目のaddの処理時間: 3003 ミリ秒
// 3度目のaddの処理時間: 2 ミリ秒
3度目のaddメソッドの呼び出しは1度目と引数が同じであり、戻り値をキャッシュから取得できたためaddメソッドの計算がスキップされました。
まとめ
Scala3のマクロアノテーションでメソッドの戻り値をLRUキャッシュで保持するコードを書いてみました。
Scala3はマクロ用の型が多くて理解が大変ですが、その分Scala2より安全にマクロが書けるようになっていると思います。
MacroAnnotation
はまだ実験版なので、一般的に使えるようになることを願いましょう。