LoginSignup
2
1

More than 5 years have passed since last update.

パターンマッチの作り方(6) ヴァリアントの型と領域取得

Last updated at Posted at 2013-06-20

今回は、ついにヴァリアントの型について考え、変数領域の取得までを実装します。

使用例

以下が、ヴァリアントを表す型です。

TVariant(List(
  "A"->TStr(List("a"->Ti(32))),
  "B"->TStr(List("a"->Ti(32), "b"->Ti(32)))
))

この型で変数領域を確保するのは以下のようにします。

// ヴァリアント
EVal(TVariant(List(
  "A"->TStr(List("a"->Ti(32))),
  "B"->TStr(List("a"->Ti(32), "b"->Ti(32)))
)),"data", null)

確保する領域のサイズは、TVariant内のリストの中の最も大きいサイズを持つ型+4バイトになります。この式をコンパイルする事で領域が確保されるところまでを実装します。

実装

型を定義します。

case class TVariant(ls:List[(String,TStr)]) extends T
%struct.A = type { i32 }
%struct.B = type { i32, i32 }
%struct.Data = type { i32, %struct.B }

%struct.Dataに入れる2番目の構造体の値は、構造体のサイズ計算が出来ないと決めることが出来ません。サイズ計算関数はダミーで1を返す関数を作っておきます。

kNormalに以下のソースを追加

      // Variant declaration: emits an alloca for a variable of the variant type.
      // The initializer `tpl` is not used by this case.
      case EVal(t: TVariant, id, tpl) =>
        // Called for its side effect: llvariantInfo registers the member structs
        // and the tag+data struct with emit. valT and maxT are not used here.
        val (_, valT, maxT, _) = emit.llvariantInfo(t)
        env.add(RL(t,id))
        add(LLAlloca(RL(t,id)))
        RL(t,id)

サイズ計算とヴァリアント情報の取得

  /**
   * Size of a type in bytes, ignoring any alignment padding.
   *
   * Structs are the sum of their field sizes and are registered with
   * `llstruct` as a side effect; a variant is its largest member plus
   * 4 bytes for the i32 tag.
   *
   * @param t the type to measure
   * @return the size in bytes
   */
  def size(t:T):Int = {
    t match {
      case s@TStr(fields) =>
        llstruct(s) // register the struct under a generated name
        fields.map { case (_, fieldType) => size(fieldType) }.sum
      // Round up to whole bytes so sub-byte integers (e.g. i1) are not sized 0.
      case Ti(bits) => (bits + 7) / 8
      case Tv => 0
      case _:TFun => 8 // NOTE(review): assumes 8-byte function pointers (64-bit target)
      case v:TVariant =>
        val (_, _, largest, _) = llvariantInfo(v)
        size(largest) + 4
    }
  }
  /**
   * Collects the layout information of a variant type.
   *
   * Every member struct is measured with `size` (which also registers it),
   * and the first member with the largest size becomes the payload of the
   * generated `{ tag: i32, data: <largest member> }` struct.
   *
   * @param v the variant type
   * @return (LLVM name of the tag+data struct, the tag+data struct,
   *         the largest member struct, the member structs in reverse
   *         declaration order)
   * @throws Exception if the variant declares no members
   */
  def llvariantInfo(v:TVariant):(String,TStr,TStr,List[TStr]) = {
    if (v.ls.isEmpty) throw new Exception("empty variant " + v)
    // Start below every possible size so the first member is always selected,
    // even when all members are zero-sized; otherwise maxt would stay null.
    val (_, maxt, tys) = v.ls.foldLeft((-1, null:TStr, List[TStr]())) {
      case ((n, best, acc), (_, vt)) =>
        val sizevt = size(vt)
        if (sizevt > n) (sizevt, vt, vt::acc) else (n, best, vt::acc)
    }
    val t = TStr(List("tag"->Ti(32),"data"->maxt))
    (llstruct(t), t, maxt, tys)
  }

  def llvariant(v:TVariant):String = {
    llvariantInfo(v) match {
      case (s,_,_,_) => s
    }
  }

そして、全リストの中で最も大きいサイズと型を求めます。

emitのlltメソッドに追加します。

      // A variant is rendered as the name of its generated tag+data struct.
      case t:TVariant => llvariant(t)

ソース全体が以下になります。

package chapter06

import java.io._

/** Source-level expression tree; every node carries its type `t`. */
sealed trait E {
  def t:T
}
/** Literal constant `i` of type `t`. */
case class ELdc(t:T, i:Long) extends E
/** Binary operation `l <s> r`, where `s` is the operator name (e.g. "add", "mul"). */
case class EBin(t:T, s:String, l:E, r:E) extends E
/** Prints the value of `a` through the `print_<type>` function. */
case class EPrint(t:T, a:E) extends E
/** Sequence of expressions evaluated in order. */
case class EBlock(t: T, ls: List[E]) extends E
/** Declares variable `id` of type `t` with initializer `a` (null for uninitialized structs/variants). */
case class EVal(t: T, id: String, a: E) extends E
/** Reference to the variable `id`. */
case class EId(t: T, id: String) extends E
/** Stores the value of `b` into the location denoted by `a`. */
case class EAssign(t: T, a: E, b: E) extends E
/** Field `idx` of the struct variable `id`. */
case class EField(t: T, id: String, idx: String) extends E
/** Struct initializer literal: one expression per field, in field order. */
case class ETuple(t:T,ls:List[E]) extends E

/** Types of the source language. */
sealed trait T
/** Integer type of `i` bits (rendered as LLVM `i<i>`). */
case class Ti(i:Int) extends T
/** The void type. */
case object Tv extends T
/** Function type returning `t` and taking parameters of types `prms`. */
case class TFun(t: T, prms: List[T]) extends T
/** Struct type with named fields, in declaration order. */
case class TStr(types: List[(String, T)]) extends T
/** Variant type: a list of named alternatives, each of which is a struct. */
case class TVariant(ls:List[(String,TStr)]) extends T

object T {
  /**
   * Looks up the field named `a` in the struct type `t`.
   *
   * @return the zero-based index and type of the first field with that name,
   *         or `(-1, Tv)` when the struct has no such field
   */
  def find(t:TStr, a: String): (Int, T) = {
    val position = t.types.indexWhere { case (fieldName, _) => a == fieldName }
    if (position < 0) (-1, Tv)
    else (position, t.types(position)._2)
  }
}

/** Factory for binary-operator nodes; `s` is the operator name stored in the node. */
case class Op(s: String) {
  /** Builds the `EBin` node applying this operator to `a` and `b` with result type `t`. */
  def apply(t: T, a: E, b: E): E = EBin(t, s, a, b)
}

/** The "add" operator. */
object EAdd extends Op("add")

/** The "mul" operator. */
object EMul extends Op("mul")

/** An operand of the intermediate code: a typed register, symbol or literal. */
sealed trait R {
  def t:T
  def id:String
}
/** Global symbol, printed as `@id`. */
case class RG(t:T, id: String) extends R
/** Named local variable, printed as `%id`. */
case class RL(t:T, id: String) extends R
/** Generated temporary register, printed as `%.id`. */
case class RR(t:T, id: String) extends R
/** Immediate value, printed verbatim as `id`. */
case class RN(t:T, id: String) extends R

/** Driver: compiles a fixed sample program to LLVM IR, builds it and runs it. */
object test {
  import scala.util.control.NonFatal

  /**
   * Builds the sample AST, lowers it, writes `e.ll`, then runs `llc`,
   * `llvm-gcc` and the resulting binary. Any non-fatal failure is reported
   * with a stack trace.
   */
  def main(argv: Array[String]): Unit = {
    try {
      val ast = EBlock(Tv, List(
        EPrint(Ti(32), ELdc(Ti(32), 11)),
        EPrint(Ti(32), EAdd(Ti(32), ELdc(Ti(32), 11), ELdc(Ti(32), 22))),
        // variable a: constant
        EVal(Ti(32), "a", ELdc(Ti(32), 11)),
        EPrint(Ti(32), EId(Ti(32), "a")),
        // variable b: addition
        EVal(Ti(32), "b", EAdd(Ti(32), ELdc(Ti(32), 11), ELdc(Ti(32), 22))),
        EPrint(Ti(32), EId(Ti(32), "b")),
        // variable c: value of another variable
        EVal(Ti(32), "c", EId(Ti(32), "a")),
        EPrint(Ti(32), EId(Ti(32), "c")),
        // struct
        EVal(TStr(List(("a", Ti(32)), ("b", Ti(32)))), "aa", null),
        EAssign(Ti(32), EField(Ti(32), "aa", "a"), ELdc(Ti(32), 9)),
        EAssign(Ti(32), EField(Ti(32), "aa", "b"), EId(Ti(32), "c")),
        EPrint(Ti(32), EField(Ti(32), "aa", "a")),
        EPrint(Ti(32), EField(Ti(32), "aa", "b")),
        // struct initializer literal
        EVal(TStr(List(("a", Ti(32)), ("b", Ti(32)))), "ab",
          ETuple(TStr(List(("a", Ti(32)), ("b", Ti(32)))),
            List(ELdc(Ti(32),123),ELdc(Ti(32),456)))),
        EPrint(Ti(32), EField(Ti(32), "ab", "a")),
        EPrint(Ti(32), EField(Ti(32), "ab", "b")),

        // variant
        EVal(TVariant(List(
          "A"->TStr(List("a"->Ti(32))),
          "B"->TStr(List("a"->Ti(32), "b"->Ti(32)))
        )),"data", null)
      ))
      println("ast=" + ast)
      val ll = kNormal(ast)
      println("ll=" + ll)
      val ll2 = constFold(ll)
      emit("e.ll", ll2)
      run("llc e.ll -o e.s")
      run("llvm-gcc -m64 e.s -o e")
      run("./e")
    } catch {
      // NonFatal lets OutOfMemoryError, InterruptedException etc. propagate.
      case NonFatal(e) => e.printStackTrace()
    }
  }

  /**
   * Runs `cmd`, prints its (exit code, stdout, stderr) result and aborts the
   * pipeline when the command fails, so that a stale `e.s` or `e` left over
   * from an earlier build is never used.
   */
  private def run(cmd: String): Unit = {
    val result = exec(cmd)
    println(result)
    if (result._1 != 0)
      throw new Exception("command failed (exit " + result._1 + "): " + cmd)
  }
}

/**
 * Lowers the expression tree into a flat list of `LL` instructions.
 * Variables are recorded in the global `env`; instructions are accumulated
 * in `ls` (newest first) and returned in program order by `apply`.
 */
object kNormal {
  /** Creates a fresh temporary register of type `t`. */
  def gid(t:T): R = {
    RR(t,genid(""))
  }

  /** Instructions emitted so far, newest first; reset by `apply`. */
  var ls: List[LL] = List[LL]()

  /** Appends `l` to the instruction stream. */
  def add(l: LL): Unit = {
    ls = l :: ls
  }

  /** Resolves a variable name, failing with a readable message when it is unknown. */
  private def lookup(id: String): R = {
    env.map.getOrElse(id, throw new Exception("undefined variable " + id))
  }

  /**
   * Lowers an assignable expression to the register holding its address.
   *
   * @throws Exception if the variable is undefined, is not a struct, the
   *                   field does not exist, or `e` is not assignable
   */
  def arr(e: E): R = {
    e match {
      case EField(t, id, idx) =>
        val base = lookup(id)
        base.t match {
          case st: TStr =>
            val (n, nt) = T.find(st, idx)
            // T.find signals a missing field with index -1; reject it instead
            // of emitting a getelementptr with a negative index.
            if (n < 0) throw new Exception("field not found " + idx + " in " + st)
            val reg1 = gid(t)
            add(LLField(reg1, base, RN(Ti(64),"0"), RN(nt,""+n)))
            reg1
          case other => throw new Exception("type mismatch " + other)
        }
      case EId(t, id) => lookup(id)
      case _ => throw new Exception("error")
    }
  }

  /**
   * Lowers an expression, emitting instructions as a side effect.
   *
   * @return the register holding the value of `a`
   * @throws Exception on a type mismatch or an undefined variable
   */
  def f(a: E): R = {
    a match {
      case EBin(t, op, a1, b1) =>
        val lhs = f(a1)
        val rhs = f(b1)
        val id = gid(t)
        if (t != lhs.t || t != rhs.t) throw new Exception("type mismatch " + t)
        add(LLBin(id, op, lhs, rhs))
        id
      case ELdc(t, i) => RN(t, ""+i)
      case EPrint(t, arg) =>
        val v = f(arg)
        if (t != v.t) throw new Exception("type mismatch t=" + t + " ta=" + v.t)
        add(LLCall(null, RG(TFun(Tv, List(t)), "print_" + emit.llt(t)), List((v.t, v))))
        v
      case EBlock(t, es) =>
        // The block's value is the value of its last expression (null when empty).
        es.foldLeft(null: R) {
          case (_, e) => f(e)
        }
      case EVal(t: TStr, id, tpl) =>
        emit.llstruct(t)
        env.add(RL(t,id))
        add(LLAlloca(RL(t,id)))
        tpl match {
          case ETuple(_, inits) =>
            // Initialize field by field, pairing initializers with fields in order.
            for ((init, (name, ft)) <- inits.zip(t.types)) {
              f(EAssign(ft, EField(ft, id, name), init))
            }
          case null =>
          case _ => throw new Exception("error")
        }
        RL(t,id)
      case EVal(t: TVariant, id, tpl) =>
        // Called only to register the member structs and the tag+data struct.
        // The initializer `tpl` is not used by this case.
        emit.llvariant(t)
        env.add(RL(t,id))
        add(LLAlloca(RL(t,id)))
        RL(t,id)
      case EVal(t, id, init) =>
        val v = f(init)
        // The declared type must match the initializer; otherwise env and the
        // emitted LLAssign would name different registers for the same variable.
        if (t != v.t) throw new Exception("type mismatch " + t + " " + v.t)
        env.add(RL(t, id))
        add(LLAssign(RL(t, id), v))
        RL(t, id)
      case EId(t, id) => lookup(id)
      case EAssign(t, dstE, srcE) =>
        val dst = arr(dstE)
        val src = f(srcE)
        if (t != src.t) throw new Exception("type mismatch " + t + " " + src.t)
        add(LLStore(src, dst))
        src
      case fld: EField =>
        val addr = arr(fld)
        val loaded = gid(addr.t)
        add(LLLoad(loaded, addr))
        loaded
    }
  }

  /** Lowers the whole program `a` and returns its instructions in program order. */
  def apply(a: E): List[LL] = {
    ls = List[LL]()
    f(a)
    ls.reverse
  }
}

/** Global symbol table mapping variable names to their registers. */
object env {
  var map = Map[String, R]()

  /** Binds `r.id` to `r`, replacing any earlier binding of the same name. */
  def add(r: R): Unit = {
    map = map.updated(r.id, r)
  }
}

/** Instructions of the flat intermediate code produced by kNormal. */
sealed trait LL
/** Calls `op` with typed arguments `prms`; the result goes to `id` (null when there is none). */
case class LLCall(id: R, op: R, prms: List[(T, R)]) extends LL
/** Binary operation `id = op a, b`. */
case class LLBin(id: R, op: String, a: R, b: R) extends LL
/** Binds `s` to the value `d`; removed by constFold, and emit rejects it. */
case class LLAssign(s: R, d: R) extends LL
/** Field address computation: `id1 = getelementptr aid, z, b`. */
case class LLField(id1: R, aid: R, z: R, b: R) extends LL
/** Allocates storage for `id` with `alloca`. */
case class LLAlloca(id: R) extends LL
/** Loads the value at address `id2` into `id1`. */
case class LLLoad(id1: R, id2: R) extends LL
/** Stores the value `id1` to address `id2`. */
case class LLStore(id1: R, id2: R) extends LL

/**
 * Removes `LLAssign` instructions by substituting every later use of the
 * assigned register with the register it was bound to.
 */
object constFold {
  /** Substitutions collected so far (assigned register -> its value); set by `apply`. */
  var map: Map[R, R] = null

  /** Follows the substitution chain starting at `v` until an unmapped register is reached. */
  def m(v: R): R = {
    map.get(v) match {
      case Some(target) => m(target)
      case None => v
    }
  }

  /** Applies the substitution to the register of every call argument. */
  def fs(prms: List[(T, R)]): List[(T, R)] = {
    prms.map {
      case (tpe, reg) => (tpe, m(reg))
    }
  }

  /**
   * Rewrites the instruction list in order: assignments are recorded and
   * dropped, all other instructions are kept with their operands substituted.
   */
  def apply(ls: List[LL]): List[LL] = {
    map = Map()
    val out = List.newBuilder[LL]
    for (instr <- ls) {
      instr match {
        case call @ LLCall(_, _, prms) => out += call.copy(prms = fs(prms))
        case LLBin(id, op, a, b) => out += LLBin(id, op, m(a), m(b))
        case LLAssign(s, d) => map = map + (s -> d)
        case LLAlloca(id: R) => out += LLAlloca(m(id))
        case LLField(id, id2, id3, id4) => out += LLField(id, m(id2), m(id3), m(id4))
        case LLStore(id1, id2) => out += LLStore(m(id1), m(id2))
        case LLLoad(id1, id2) => out += LLLoad(m(id1), m(id2))
        case other => throw new Exception("error no implementation "+other)
      }
    }
    out.result()
  }
}

/**
 * Writes the intermediate code as LLVM assembly and keeps the registry of
 * struct types that need a type definition in the output.
 */
object emit {

  /** Renders the type `t` in LLVM syntax; structs and variants are rendered by name. */
  def llt(t:T):String = {
    t match {
      case Ti(i) => "i" + i
      case Tv => "void"
      case TFun(ret, prms) => llt(ret) + "(" + prms.map(llt).mkString(", ") + ")*"
      case s:TStr => llstruct(s)
      case v:TVariant => llvariant(v)
    }
  }

  /**
   * Size of a type in bytes, ignoring any alignment padding.
   *
   * Structs are the sum of their field sizes and are registered with
   * `llstruct` as a side effect; a variant is its largest member plus
   * 4 bytes for the i32 tag.
   */
  def size(t:T):Int = {
    t match {
      case s@TStr(fields) =>
        llstruct(s) // register the struct so that its type definition is emitted
        fields.map { case (_, fieldType) => size(fieldType) }.sum
      // Round up to whole bytes so sub-byte integers (e.g. i1) are not sized 0.
      case Ti(bits) => (bits + 7) / 8
      case Tv => 0
      case _:TFun => 8 // NOTE(review): assumes 8-byte function pointers (-m64 build)
      case v:TVariant =>
        val (_, _, largest, _) = llvariantInfo(v)
        size(largest) + 4
    }
  }

  /**
   * Collects the layout information of a variant type.
   *
   * Every member struct is measured with `size` (which also registers it),
   * and the first member with the largest size becomes the payload of the
   * generated `{ tag: i32, data: <largest member> }` struct.
   *
   * @return (LLVM name of the tag+data struct, the tag+data struct,
   *         the largest member struct, the member structs in reverse
   *         declaration order)
   * @throws Exception if the variant declares no members
   */
  def llvariantInfo(v:TVariant):(String,TStr,TStr,List[TStr]) = {
    if (v.ls.isEmpty) throw new Exception("empty variant " + v)
    // Start below every possible size so the first member is always selected,
    // even when all members are zero-sized; otherwise maxt would stay null.
    val (_, maxt, tys) = v.ls.foldLeft((-1, null:TStr, List[TStr]())) {
      case ((n, best, acc), (_, vt)) =>
        val sizevt = size(vt)
        if (sizevt > n) (sizevt, vt, vt::acc) else (n, best, vt::acc)
    }
    val t = TStr(List("tag"->Ti(32),"data"->maxt))
    (llstruct(t), t, maxt, tys)
  }

  /** LLVM name of the struct that represents the variant `v`. */
  def llvariant(v:TVariant):String = {
    llvariantInfo(v)._1
  }

  /** Renders an operand: `@` for globals, `%` for locals, `%.` for temporaries, verbatim for immediates. */
  def llr(r:R): String = {
    r match {
      case RG(t,id) => "@" + id
      case RL(t,id) => "%" + id
      case RR(t,id) => "%." + id
      case RN(t,id) => "" + id
    }
  }

  /** Writes `out`, prefixed with `<id> = ` when the instruction has a result register. */
  def o(id: R, out: String): Unit = {
    if (id != null) asm(llr(id) + " = " + out)
    else asm(out)
  }

  /**
   * Writes one instruction.
   *
   * @throws Exception for `LLAssign`, which constFold is expected to have removed
   */
  def f(l: LL): Unit = {
    l match {
      case LLCall(id, op, prms) =>
        val ps = prms.map { case (a, b) => llt(a) + " " + llr(b) }.mkString(", ")
        o(id, "call " + llt(op.t) + " " + llr(op) + "(" + ps + ") nounwind")
      case LLBin(id, op, a, b) =>
        o(id, op + " " + llt(id.t) + " " + llr(a) + ", " + llr(b))
      case _:LLAssign => throw new Exception("error")
      case LLField(reg1: R, addr: R, zero: R, a: R) =>
        o(reg1, "getelementptr inbounds " + llt(addr.t) + "* " + llr(addr) + ", " + llt(zero.t) + " " + llr(zero) + ", " + llt(a.t) + " " + llr(a))
      case LLLoad(reg1: R, reg2: R) =>
        o(reg1, "load " + llt(reg1.t) + "* " + llr(reg2))
      case LLStore(reg1: R, reg2: R) =>
        asm("store " + llt(reg1.t) + " " + llr(reg1) + ", " + llt(reg1.t) + "* " + llr(reg2))
      case LLAlloca(reg: R) =>
        o(reg, "alloca " + llt(reg.t))
    }
  }

  /** Registered struct types and their generated LLVM names. */
  var structs: Map[TStr, String] = Map()

  /** Returns the LLVM name of `t`, generating and registering one on first use. */
  def llstruct(t: TStr): String = {
    structs.get(t) match {
      case Some(name) => name
      case None =>
        val name = genid("%.struct")
        structs = structs + (t -> name)
        name
    }
  }

  /**
   * Writes the complete module to `file`: struct type definitions, the
   * print helpers, and `main` containing the instructions `ls`.
   * The file is closed even when emission fails part-way.
   */
  def apply(file: String, ls: List[LL]): Unit = {
    asm.open(file)
    try {
      structs.foreach { case (t, n) =>
        asm(n + " = type {" + t.types.map { case (a, b) => llt(b) }.mkString(", ") + "}")
      }
      asm.label("@.str = private constant [4 x i8] c\"%d\\0A\\00\"")
      asm.label("define void @print_i32(i32 %a) nounwind ssp {")
      asm.label("entry:")
      asm("call i32 (i8*, ...)* @printf(i8* getelementptr inbounds ([4 x i8]* @.str, i64 0, i64 0), i32 %a) nounwind")
      asm("ret void")
      asm.label("}")
      asm.label("define void @print_i8(i8 %a) nounwind ssp {")
      asm.label("entry:")
      asm("call i32 (i8*, ...)* @printf(i8* getelementptr inbounds ([4 x i8]* @.str, i64 0, i64 0), i8 %a) nounwind")
      asm("ret void")
      asm.label("}")

      asm.label("declare i32 @printf(i8*, ...) nounwind")

      asm.label("define i32 @main() nounwind ssp {")
      asm.label("entry:")
      ls.foreach(f)
      asm("ret i32 0")
      asm.label("}")
    } finally {
      asm.close()
    }
  }
}

/** Generator of unique identifiers based on a global counter. */
object genid {
  var id = 0

  /** Advances the counter and returns `s` followed by the new counter value. */
  def apply(s: String): String = {
    val next = id + 1
    id = next
    s + next
  }
}
/** Line-oriented writer for the generated assembly file. */
object asm {
  /** Current output; null until `open` has been called. */
  var p: PrintWriter = null

  /** Opens `file` for writing and makes it the current output. */
  def open(file: String): Unit = {
    val buffered = new BufferedWriter(new FileWriter(file))
    p = new PrintWriter(buffered)
  }

  /** Prefix written before every line; set by `label`. */
  var indent: String = ""

  /** Writes one line: the indent, `s`, a tab, `n`, and a newline. */
  def apply(s: String, n: String = ""): Unit = {
    p.print(indent)
    p.print(s)
    p.print("\t")
    p.print(n)
    p.print("\n")
  }

  /** Writes `s` without indentation; the lines that follow are indented by one tab. */
  def label(s: String): Unit = {
    indent = ""
    apply(s)
    indent = "\t"
  }

  /** Closes the current output. */
  def close(): Unit = {
    p.close()
  }
}

/** Runs external commands and captures their output. */
object exec {
  /**
   * Runs `cmd` and waits for it to finish.
   *
   * stderr is drained on a separate thread while stdout is read, so a child
   * that writes a lot to stderr cannot block on a full pipe while this side
   * is still waiting for stdout to reach end-of-file.
   *
   * @param cmd the command line, tokenized by `Runtime.exec`
   * @return (exit code, captured stdout, captured stderr)
   */
  def apply(cmd: String): (Int, String, String) = {
    val p = Runtime.getRuntime().exec(cmd)
    var stderr = ""
    val errDrainer = new Thread(new Runnable {
      def run(): Unit = {
        stderr = readAll(p.getErrorStream())
      }
    })
    errDrainer.start()
    val stdout = readAll(p.getInputStream())
    // join() also makes the drainer's write to `stderr` visible to this thread.
    errDrainer.join()
    (p.waitFor(), stdout, stderr)
  }

  /**
   * Reads the stream line by line until end-of-file and closes it.
   *
   * @return all lines, each terminated by "\n" (including the last one)
   */
  def readAll(p: InputStream): String = {
    val reader = new BufferedReader(new InputStreamReader(p))
    try {
      val buf = new StringBuilder
      var line = reader.readLine()
      while (line != null) {
        buf.append(line).append("\n")
        line = reader.readLine()
      }
      buf.toString
    } finally {
      reader.close()
    }
  }
}

まとめ

今回はヴァリアントの型の変数領域の確保まで実装をしました。
次回は、ヴァリアント値の初期化を行います。

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1