Posted at

GADTでエレガントに抽象構文木の階層を作る

いきなりですが、何かの言語とかの抽象構文木を表すときに、いつも困っている問題がありました。たとえば、PEGにおける式は、

e ::= e e       //連接

| e '/' e //選択
| e '*' //反復(0回以上)
| e '+' //反復(1回以上)
| '&' e //肯定先読み
| '!' e //否定先読み
| "..." //文字列

のように表現できますが、これを素直にScalaで表現すると次のようになります。

sealed trait Expression

case class Alpha(ch: Char) extends Expression
case class Choice(lhs: Expression, rhs: Expression) extends Expression
case class Sequence(lhs: Expression, rhs: Expression) extends Expression
case class Repeat0(body: Expression) extends Expression
case class Not(body: Expression) extends Expression
case class Repeat1(body: Expression) extends Expression
case class And(body: Expression) extends Expression

とりあえず、extends Expressionの部分が冗長なのはScala 3のenumで解消されるのでよいとします。ここからが問題なのですが、このうち、Repeat1(反復(0回以上))、And(肯定先読み)は、単なる糖衣構文とみなすことができます。具体的には、

e+ ==> e e*

&e ==> !!e

というシンプルな変換で置き換えることができます(実際には、反復も否定先読みも除去できるのですが、話が複雑になるのでおいておきます)。

ここで、こういう抽象構文木を処理するときに、糖衣構文を除去したものに対して何らかの解析処理なり解釈をしたいことがあるのですが、糖衣構文を除去した後も型は変わらないため、パターンマッチの網羅性検査がうまく働かない(正確には過剰な警告を出してしまう)のです。

たとえば、上記の構文糖衣を除去したものを解釈したいと思って、

def display(e: Expression): String = e match {

case Alpha(ch) => s"'${ch}"
case Choice(lhs, rhs) => s"(${display(lhs)}/${display(rhs)})"
case Sequence(lhs, rhs) => s"(${display(lhs)} ${display(rhs)})"
case Repeat0(body) => s"(${display(body)})*"
case Not(body) => s"!(${display(body)})"
}

とすると、Repeat1Andの分岐が足りない旨の警告が出てしまいます。そして、これまでは、そういうケースに対して、

case _ => sys.error("should not reach here")

のような分岐を付け足して警告を抑止していたのですが、明らかに来ないことがわかっているもののためにこの分岐を書くのは気持ちが悪いものがありました。

で、この問題をなんとかできないかなと考えていたところ、「これ、GADTでいける感じの問題では?」と思って、なんとなくざっと書いてみたら、ほぼ一発でうまく動きました。以下がそのコードです。

object PEG {

sealed trait Category
class Full extends Category
final class Core extends Full

sealed trait Expression[+A]
case class Alpha(ch: Char) extends Expression[Core]
case class Choice[+A <: Category](lhs: Expression[A], rhs: Expression[A]) extends Expression[A]
case class Sequence[+A <: Category](lhs: Expression[A], rhs: Expression[A]) extends Expression[A]
case class Repeat0[+A <: Category](body: Expression[A]) extends Expression[A]
case class Not[+A <: Category](body: Expression[A]) extends Expression[A]
case class Repeat1(body: Expression[Full]) extends Expression[Full]
case class And(body: Expression[Full]) extends Expression[Full]

def desugar(e: Expression[Full]): Expression[Core] = e match {
case Alpha(ch) => Alpha(ch)
case Choice(lhs, rhs) => Choice(desugar(lhs), desugar(rhs))
case Sequence(lhs, rhs) => Sequence(desugar(lhs), desugar(rhs))
case Repeat0(body) => Repeat0(desugar(body))
case Not(body) => Not(desugar(body))
case Repeat1(body) => Sequence(desugar(body), Repeat0(desugar(body)))
case And(body) => Not(Not(desugar(body)))
}

def display(e: Expression[Core]): String = e match {
case Alpha(ch) => s"'${ch}"
case Choice(lhs, rhs) => s"(${display(lhs)}/${display(rhs)})"
case Sequence(lhs, rhs) => s"(${display(lhs)} ${display(rhs)})"
case Repeat0(body) => s"(${display(body)})*"
case Not(body) => s"!(${display(body)})"
}

def main(args: Array[String]): Unit = {
val e: Expression[Full] = Sequence(
And(Choice(Alpha('a'), Alpha('b'))), Repeat1(Choice(Alpha('a'), Alpha('b')))
)
//(!(!(('a/'b))) (('a/'b) (('a/'b))*))
println(display(desugar(e)))
}
}

先ほどと違って、Expressionが共変な型パラメータAを持っていて、各ノードはそれを、特定の型パラメータに特化した形で継承しているのがポイントです。ここで、シンタックスシュガーの部分には、Expression[Full]という型が付き、コアの式にはExpression[Core]という型がつくので、上記のdisplayメソッドは警告がでず(網羅性検査がされた上で通ります)、無事、長い間悩んでいた問題(の一部)が解決されたのでした。