7
0

[Scala3] マクロアノテーションでメソッドの戻り値をLRUキャッシュ

Last updated at Posted at 2023-12-16

概要

以下の記事を参考にScala3のマクロアノテーションを使ってメソッドの戻り値をLRUキャッシュ (Least Recently Used. 直近に参照した時刻が一番古いものを捨てるキャッシュ) で保持するコードを書きました。

注意

Scala3のマクロアノテーションはMacroAnnotationを継承して定義できます。しかし、MacroAnnotationはバージョン3.3.0-RC2の実験的な機能であり、アノテーションクラスとそれを利用するクラスに@experimentalアノテーションを付ける必要があります。

LRUキャッシュの実装

LRUCache.scala
import scala.collection.mutable.LinkedHashMap

class LRUCache[K, V](size: Int):
  private val cache = LinkedHashMap.empty[K, V]
  private val queue = new DoublyLinkedList[K]

  def get(key: K): Option[V] =
    cache.get(key) match
      case Some(value) =>
        queue.moveToFront(key)
        Some(value)
      case None => None

  def put(key: K, value: V): Unit =
    if (cache.contains(key)) queue.moveToFront(key)
    else
      if (cache.size >= size)
        val last = queue.removeLast()
        cache.remove(last)
      queue.addToFront(key)
    cache.put(key, value)

class DoublyLinkedList[K]:
  case class Node(key: K, var prev: Node, var next: Node)

  private val head = Node(null.asInstanceOf[K], null, null)
  private val tail = Node(null.asInstanceOf[K], head, null)
  head.next = tail

  private val map = scala.collection.mutable.Map.empty[K, Node]

  def addToFront(key: K): Unit =
    val node = Node(key, head, head.next)
    head.next.prev = node
    head.next = node
    map.put(key, node)

  def remove(node: Node): Unit =
    node.prev.next = node.next
    node.next.prev = node.prev
    map.remove(node.key)

  def removeLast(): K =
    val last = tail.prev
    remove(last)
    last.key

  def moveToFront(key: K): Unit =
    map.get(key) match
      case Some(node) =>
        remove(node)
        addToFront(key)
      case None => ()

ほぼ、Scala Interview Series: Implement LRU cache in scalaのコードのコピペです。
サイズが指定可能でシンプルなLRUキャッシュです。

マクロアノテーションの実装

lruCached.scala
import scala.annotation.{MacroAnnotation, experimental}
import scala.quoted.*

@experimental
class lruCached(size: Int) extends MacroAnnotation:
  // [1]
  override def transform(using quotes: Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] =
    import quotes.reflect.*

    // [2]
    tree match
      case DefDef(name, params, returnType, Some(rhs)) =>
        // [3]
        val flattenedParams = params.map(_.params).flatten
        val paramTermRefs = flattenedParams.map(_.asInstanceOf[ValDef].symbol.termRef)
        val paramTuple = Expr.ofTupleFromSeq(paramTermRefs.map(Ident(_).asExpr))

        // [4]
        (paramTuple, rhs.asExpr) match
          case ('{ $p: paramTupleType }, '{ $r: rhsType }) =>
            // [5]
            val cacheName = Symbol.freshName(name + "Cache")
            val cacheType = TypeRepr.of[LRUCache[paramTupleType, rhsType]]
            val cacheRhs = '{ LRUCache[paramTupleType, rhsType](${ Expr(size) }) }.asTerm
            val cacheSymbol = Symbol.newVal(tree.symbol.owner, cacheName, cacheType, Flags.Private, Symbol.noSymbol)
            val cache = ValDef(cacheSymbol, Some(cacheRhs))
            val cacheRef = Ref(cacheSymbol).asExprOf[LRUCache[paramTupleType, rhsType]]

            // [6]
            def buildNewRhs(using q: Quotes) =
              import q.reflect.*
              '{
                val key = ${ paramTuple.asExprOf[paramTupleType] }
                $cacheRef.get(key) match
                  case Some(value) =>
                    value
                  case None =>
                    val result = ${ rhs.asExprOf[rhsType] }
                    $cacheRef.put(key, result)
                    result
              }
            val newRhs = buildNewRhs(using tree.symbol.asQuotes).asTerm
            // [7]
            val expandedMethod = DefDef.copy(tree)(name, params, returnType, Some(newRhs))
            List(cache, expandedMethod)
      case _ =>
        report.error("Annottee must be a method")
        List(tree)

[1] transformメソッド

override def transform(using quotes: Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] =

MacroAnnotationが提供するtransformメソッドはtreeの定義を変換し、新しい定義を追加できます。
treeの型Definitionはソースコード内の定義のツリー構造を表現しており、ClassDefTypeDefDefDefValDefのいずれかです。

[2] Definitionのパターンマッチ

tree match
      case DefDef(name, params, returnType, Some(rhs)) =>
        ...
      case _ =>
        report.error("Annottee must be a method")
        List(tree)

treeDefDef型 (ソースコード内のメソッド定義を表すツリー構造)であるかチェックしています。
メソッド以外にアノテーションが付けられている場合、コンパイルエラーとなります。

case DefDef(name, params, returnType, Some(rhs))unapplyメソッドを呼んでおり、unapplyのシグネチャは以下のように定義されています。

def unapply(ddef: DefDef): (String, List[ParamClause], TypeTree, Option[Term])

[3] メソッドのパラメータの処理

val flattenedParams = params.map(_.params).flatten
val paramTermRefs = flattenedParams.map(_.asInstanceOf[ValDef].symbol.termRef)
val paramTuple = Expr.ofTupleFromSeq(paramTermRefs.map(Ident(_).asExpr))

メソッドのパラメータをタプルの式の変換しています。

[4] 式のパターンマッチ

(paramTuple, rhs.asExpr) match
  case ('{ $p: paramTupleType }, '{ $r: rhsType }) =>

パラメータのタプル式のメソッド内の式をパターンマッチしています。$pはパラメータのタプル式の値、$rはメソッド内の式の値を取り出します。
quoted code block ('{ ... })については、ドキュメントをご参照ください。

[5] LRUキャッシュを定義

val cacheName = Symbol.freshName(name + "Cache")
val cacheType = TypeRepr.of[LRUCache[paramTupleType, rhsType]]
val cacheRhs = '{ LRUCache[paramTupleType, rhsType](${ Expr(size) }) }.asTerm
val cacheSymbol = Symbol.newVal(tree.symbol.owner, cacheName, cacheType, Flags.Private, Symbol.noSymbol)
val cache = ValDef(cacheSymbol, Some(cacheRhs))
val cacheRef = Ref(cacheSymbol).asExprOf[LRUCache[paramTupleType, rhsType]]

Symbolは等しい文字列の一意のオブジェクトを取得するためのクラスです。
Refは定義への参照を表すツリー構造です。

[6] 変換先のメソッドを定義

def buildNewRhs(using q: Quotes) =
  import q.reflect.*
  '{
    val key = ${ paramTuple.asExprOf[paramTupleType] }
    $cacheRhs.get(key) match
      case Some(value) =>
        value
      case None =>
        val result = ${ rhs.asExprOf[rhsType] }
        $cacheRhs.put(key, result)
        result
  }
val newRhs = buildNewRhs(using tree.symbol.asQuotes).asTerm

タプル化したパラメータをkeyとしてキャッシュに問い合わせ、戻り値があればそれを返却、なければ元のメソッドを実行し戻り値をキャッシュに保存します。

[7] メソッドを置き換えて返却する

val expandedMethod = DefDef.copy(tree)(name, params, returnType, Some(newRhs))
List(cache, expandedMethod)

expandedMethod を入力メソッドのコピーとして作成し、元のメソッドを置き換えます。
そして、LRUキャッシュの定義と合わせて返却します。

参考: https://www.codecentric.de/wissens-hub/blog/macro-annotations-in-scala-3

動作確認

Main.scala
import scala.annotation.experimental

@experimental
object Main extends App:

  class Calculator:
    @lruCached(10)
    def add(x: Int, y: Int): Int =
      Thread.sleep(3000)
      x + y

  val calculator = Calculator()

  val time1 = System.currentTimeMillis
  
  println(calculator.add(1, 2))
  val time2 = System.currentTimeMillis
  
  println(calculator.add(2, 3))
  val time3 = System.currentTimeMillis
  
  println(calculator.add(1, 2))
  val time4 = System.currentTimeMillis

  println("1度目のaddの処理時間: " + (time2 - time1) + " ミリ秒")
  println("2度目のaddの処理時間: " + (time3 - time2) + " ミリ秒")
  println("3度目のaddの処理時間: " + (time4 - time3) + " ミリ秒")

  // 出力結果
  // 3
  // 5
  // 3
  // 1度目のaddの処理時間: 3067 ミリ秒
  // 2度目のaddの処理時間: 3003 ミリ秒
  // 3度目のaddの処理時間: 2 ミリ秒

3度目のaddメソッドの呼び出しは1度目と引数が同じであり、戻り値をキャッシュから取得できたためaddメソッドの計算がスキップされました。

まとめ

Scala3のマクロアノテーションでメソッドの戻り値をLRUキャッシュで保持するコードを書いてみました。
Scala3はマクロ用の型が多くて理解が大変ですが、その分Scala2より安全にマクロが書けるようになっていると思います。
MacroAnnotationはまだ実験版なので、一般的に使えるようになることを願いましょう。

7
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
0