LoginSignup
1
1

More than 5 years have passed since last update.

パターンマッチの作り方(4) 構造体

Last updated at Posted at 2013-06-20

構造体を作って領域を確保し、値を設定し、読み込みを行います。

実装

代入の式と構造体アクセスの式を定義します。
EAssignが変数への束縛や代入、EFieldが構造体のアクセス用の式です。

EFieldのidの箇所は複雑な構造体が扱えるようにする場合は、文字列ではなく式(E)にするとよいのですが、簡易的にするために文字列にしてあります。Fieldと言う名前はJVMのbytecode[1]を参考にしています。

case class EAssign(t: T, a: E, b: E) extends E
case class EField(t: T, id: String, idx: String) extends E

構造体の型を定義します。

case class TStr(types: List[(String, T)]) extends T

構造体のフィールド情報を取得する関数を以下のように定義します。構造体のidを指定すると型が返ります。

object T {
  def find(t:TStr, a: String): (Int, T) = {
    def f(i: Int, xs: List[(String, T)]): (Int, T) = {
      xs match {
        case List() => (-1, Tv)
        case (x, t) :: xs => if (a == x) (i, t) else f(i + 1, xs)
      }
    }
    f(0, t.types)
  }
}

テストコードを追加します。

// 構造体
EVal(TStr(List(("a", Ti(32)), ("b", Ti(32)))), "aa", null),
EAssign(Ti(32), EField(Ti(32), "aa", "a"), ELdc(Ti(32), 9)),
EAssign(Ti(32), EField(Ti(32), "aa", "b"), EId(Ti(32), "c")),
EPrint(Ti(32), EField(Ti(32), "aa", "a")),
EPrint(Ti(32), EField(Ti(32), "aa", "b"))

kNormalにEFieldアクセス用の関数arrを作成します。
この関数はenvから変数を取り出して、フィールド情報を取り出すアセンブラを出力して、結果のレジスタを返します。arrはEFieldやEAssignから呼び出されます。

  def arr(e: E): R = {
    e match {
      case EField(t, id, idx) =>
        env.map(id) match {
          case i:R =>
            val ((n, nt), reg1) = (T.find(i.t.asInstanceOf[TStr],idx), gid(t))
            add(LLField(reg1, i, RN(Ti(64),"0"), RN(nt,""+n)))
            reg1
          case t => throw new Exception("type mismatch " + t)
        }
      case EId(t, id) => env.map(id)
      case _ => throw new Exception("error")
    }
  }

kNormalにEValのTStrだった場合の処理の追加をします。case EVal(t:T, id, null)より前に配置します。

      case EVal(t: TStr, id, null) =>
        emit.llstruct(t)
        env.add(RL(t,id))
        add(LLAlloca(RL(t,id)))
        RL(t,id)

kNormalにEAssignとEFieldの処理の追加をします。どちらもarrを呼び出してポインタを取り出しEAssginならLLStoreで保存し、EFIeldならLLLoadで値を読み込んでいます。

      case EAssign(t, a, b) =>
        (arr(a), f(b)) match {
          case (a, b) =>
            if (t != b.t) throw new Exception("type mismatch " + t + " " + b.t)
            add(LLStore(b, a))
            b
        }
      case a: EField =>
        val a2 = arr(a)
        val b = gid(a2.t)
        add(LLLoad(b, a2))
        b

LLの追加をします。
LLFIeldが構造体のメンバのポインタ取り出し、LLAllocaはスタック上にローカル変数の領域を取得し、LLoadは指定アドレスからの読み込み、LLStoreは指定アドレスへの保存を行います。

case class LLField(id1: R, aid: R, z: R, b: R) extends LL
case class LLAlloca(id: R) extends LL
case class LLLoad(id1: R, id2: R) extends LL
case class LLStore(id1: R, id2: R) extends LL

constFoldへ追加分のLLの命令を追加します。

      case (ls, l @ LLAlloca(id: R)) => l.copy(m(id)) :: ls
      case (ls, l @ LLField(id, id2, id3, id4)) => l.copy(id, m(id2), m(id3), m(id4)) :: ls
      case (ls, l @ LLStore(id1, id2)) => l.copy(m(id1), m(id2)) :: ls
      case (ls, l @ LLLoad(id1, id2)) => l.copy(m(id1), m(id2)) :: ls

emit.sizeに構造体のサイズの計算を追加します

      case t:TStr => llstruct(t)

emitの処理の追加をします。

      case LLField(reg1: R, addr: R, zero: R, a: R) =>
        o(reg1, "getelementptr inbounds " + llt(addr.t) + "* " + llr(addr) + ", " + llt(zero.t) + " " + llr(zero) + ", " + llt(a.t) + " " + llr(a))
      case LLLoad(reg1: R, reg2: R) =>
        o(reg1, "load " + llt(reg1.t) + "* " + llr(reg2))
      case LLStore(reg1: R, reg2: R) =>
        asm("store " + llt(reg1.t) + " " + llr(reg1) + ", " + llt(reg1.t) + "* " + llr(reg2))
      case LLAlloca(reg: R) =>
        o(reg, "alloca " + llt(reg.t))

llstructで構造体がstructsに登録されている場合はその値を返し、登録されていない場合は名前をつけて登録します。structsはemit処理の最初に構造体を出力する必要があるので保存しています。

  var structs: Map[TStr, String] = Map()
  def llstruct(t: TStr): String = {
    if (structs.contains(t)) return structs(t)
    val name = genid("%.struct")
    structs = structs + (t -> name)
    name
  }

使用している構造体の出力を行います。

    structs.foreach { case (t, n) =>
        asm(n + " = type {" + t.types.map { case (a, b) => llt(b) }.mkString(", ") + "}")
    }

以上の処理の追加を行った結果が下のプログラムになります。

package chapter04

import java.io._

sealed trait E {
  def t:T
}
case class ELdc(t:T, i:Long) extends E
case class EBin(t:T, s:String, l:E, r:E) extends E
case class EPrint(t:T, a:E) extends E
case class EBlock(t: T, ls: List[E]) extends E
case class EVal(t: T, id: String, a: E) extends E
case class EId(t: T, id: String) extends E
case class EAssign(t: T, a: E, b: E) extends E
case class EField(t: T, id: String, idx: String) extends E

sealed trait T
case class Ti(i:Int) extends T
case object Tv extends T
case class TFun(t: T, prms: List[T]) extends T
case class TStr(types: List[(String, T)]) extends T

object T {
  def find(t:TStr, a: String): (Int, T) = {
    def f(i: Int, xs: List[(String, T)]): (Int, T) = {
      xs match {
        case List() => (-1, Tv)
        case (x, t) :: xs => if (a == x) (i, t) else f(i + 1, xs)
      }
    }
    f(0, t.types)
  }
}

case class Op(s: String) {
  def apply(t: T, a: E, b: E): E = {
    EBin(t, s, a, b)
  }
}
object EAdd extends Op("add")
object EMul extends Op("mul")

sealed trait R {
  def t:T
  def id:String
}
case class RG(t:T, id: String) extends R
case class RL(t:T, id: String) extends R
case class RR(t:T, id: String) extends R
case class RN(t:T, id: String) extends R

object test {
  def main(argv: Array[String]) {
    try {
      val ast = EBlock(Tv, List(
        EPrint(Ti(32), ELdc(Ti(32), 11)),
        EPrint(Ti(32), EAdd(Ti(32), ELdc(Ti(32), 11), ELdc(Ti(32), 22))),
        // 変数 a 定数
        EVal(Ti(32), "a", ELdc(Ti(32), 11)),
        EPrint(Ti(32), EId(Ti(32), "a")),
        // 変数 b 足し算
        EVal(Ti(32), "b", EAdd(Ti(32), ELdc(Ti(32), 11), ELdc(Ti(32), 22))),
        EPrint(Ti(32), EId(Ti(32), "b")),
        // 変数 c 変数の値
        EVal(Ti(32), "c", EId(Ti(32), "a")),
        EPrint(Ti(32), EId(Ti(32), "c")),
        // 構造体
        EVal(TStr(List(("a", Ti(32)), ("b", Ti(32)))), "aa", null),
        EAssign(Ti(32), EField(Ti(32), "aa", "a"), ELdc(Ti(32), 9)),
        EAssign(Ti(32), EField(Ti(32), "aa", "b"), EId(Ti(32), "c")),
        EPrint(Ti(32), EField(Ti(32), "aa", "a")),
        EPrint(Ti(32), EField(Ti(32), "aa", "b"))
      ))
      println("ast=" + ast)
      val ll = kNormal(ast)
      println("ll=" + ll)
      val ll2 = constFold(ll)
      emit("e.ll", ll2)
      println(exec("llc e.ll -o e.s"))
      println(exec("llvm-gcc -m64 e.s -o e"))
      println(exec("./e"))
    } catch {
      case e:Throwable => e.printStackTrace()
    }
  }
}

object kNormal {
  def gid(t:T): R = {
    RR(t,genid(""))
  }
  var ls: List[LL] = null
  def add(l: LL) {
    ls = l :: ls
  }

  def arr(e: E): R = {
    e match {
      case EField(t, id, idx) =>
        env.map(id) match {
          case i:R =>
            val ((n, nt), reg1) = (T.find(i.t.asInstanceOf[TStr],idx), gid(t))
            add(LLField(reg1, i, RN(Ti(64),"0"), RN(nt,""+n)))
            reg1
          case t => throw new Exception("type mismatch " + t)
        }
      case EId(t, id) => env.map(id)
      case _ => throw new Exception("error")
    }
  }

  def f(a: E): R = {
    a match {
      case EBin(t, op, a1, b1) =>
        (f(a1), f(b1), gid(t)) match {
          case (a, b, id) =>
            if (t != a.t || t != b.t) throw new Exception("type mismatch " + t)
            add(LLBin(id, op, a, b))
            id
        }
      case ELdc(t, i) => RN(t, ""+i)
      case EPrint(t, a) =>
        f(a) match {
          case a =>
            if (t != a.t) throw new Exception("type mismatch t=" + t + " ta=" + a.t)
            add(LLCall(null, RG(TFun(Tv, List(t)), "print_" + emit.llt(t)), List((a.t, a))))
            a
        }
      case EBlock(t, ls) =>
        ls.foldLeft(null: R) {
          case (tid, l) => f(l)
        }
      case EVal(t: TStr, id, null) =>
        emit.llstruct(t)
        env.add(RL(t,id))
        add(LLAlloca(RL(t,id)))
        RL(t,id)
      case e @ EVal(t, id, a) =>
        f(a) match {
          case a =>
            env.add(RL(t, id))
            add(LLAssign(RL(a.t, id), a))
            RL(a.t, id)
        }
      case EId(t, id) => env.map(id)
      case EAssign(t, a, b) =>
        (arr(a), f(b)) match {
          case (a, b) =>
            if (t != b.t) throw new Exception("type mismatch " + t + " " + b.t)
            add(LLStore(b, a))
            b
        }
      case a: EField =>
        val a2 = arr(a)
        val b = gid(a2.t)
        add(LLLoad(b, a2))
        b
    }
  }

  def apply(a: E): List[LL] = {
    ls = List[LL]()
    f(a)
    ls.reverse
  }
}

object env {
  var map = Map[String, R]()
  def add(r: R) {
    map = map + (r.id -> r)
  }
}

sealed trait LL
case class LLCall(id: R, op: R, prms: List[(T, R)]) extends LL
case class LLBin(id: R, op: String, a: R, b: R) extends LL
case class LLAssign(s: R, d: R) extends LL
case class LLField(id1: R, aid: R, z: R, b: R) extends LL
case class LLAlloca(id: R) extends LL
case class LLLoad(id1: R, id2: R) extends LL
case class LLStore(id1: R, id2: R) extends LL

object constFold {
  var map: Map[R, R] = null
  def m(v: R): R = {
    if (map.contains(v)) m(map(v)) else v
  }
  def fs(prms: List[(T, R)]): List[(T, R)] = {
    prms.map {
      case (t, v) => (t, m(v))
    }
  }
  def apply(ls: List[LL]): List[LL] = {
    map = Map()
    ls.foldLeft(List[LL]()) {
      case (ls, l @ LLCall(id, op, prms)) => l.copy(prms = fs(prms)) :: ls
      case (ls, l @ LLBin(id, op, a, b)) => l.copy(a = m(a), b = m(b)) :: ls
      case (ls, l @ LLAssign(s, d)) => map = map + (s -> d); ls
      case (ls, l @ LLAlloca(id: R)) => l.copy(m(id)) :: ls
      case (ls, l @ LLField(id, id2, id3, id4)) => l.copy(id, m(id2), m(id3), m(id4)) :: ls
      case (ls, l @ LLStore(id1, id2)) => l.copy(m(id1), m(id2)) :: ls
      case (ls, l @ LLLoad(id1, id2)) => l.copy(m(id1), m(id2)) :: ls
      case (ls, l) => throw new Exception("error no implementation "+l)
    }.reverse
  }
}

object emit {

  def llt(t:T):String = {
    t match {
      case Ti(i) => "i" + i
      case Tv => "void"
      case TFun(t, ls) => llt(t) + "(" + ls.map(llt).mkString(", ") + ")*"
      case t:TStr => llstruct(t)
    }
  }

  def llr(r:R): String = {
    r match {
      case RG(t,id) => "@" + id
      case RL(t,id) => "%" + id
      case RR(t,id) => "%." + id
      case RN(t,id) => "" + id
    }
  }

  def o(id: R, out: String) {
    if (id != null) asm(llr(id) + " = " + out)
    else asm(out)
  }
  def f(l: LL) {
    l match {
      case LLCall(id, op, prms) =>
        val ps = prms.map { case (a, b) => llt(a) + " " + llr(b) }.mkString(", ")
        o(id, "call " + llt(op.t) + " " + llr(op) + "(" + ps + ") nounwind")
      case LLBin(id, op, a, b) =>
        o(id, op + " " + llt(id.t) + " " + llr(a) + ", " + llr(b))
      case _:LLAssign => throw new Exception("error")
      case LLField(reg1: R, addr: R, zero: R, a: R) =>
        o(reg1, "getelementptr inbounds " + llt(addr.t) + "* " + llr(addr) + ", " + llt(zero.t) + " " + llr(zero) + ", " + llt(a.t) + " " + llr(a))
      case LLLoad(reg1: R, reg2: R) =>
        o(reg1, "load " + llt(reg1.t) + "* " + llr(reg2))
      case LLStore(reg1: R, reg2: R) =>
        asm("store " + llt(reg1.t) + " " + llr(reg1) + ", " + llt(reg1.t) + "* " + llr(reg2))
      case LLAlloca(reg: R) =>
        o(reg, "alloca " + llt(reg.t))
    }
  }

  var structs: Map[TStr, String] = Map()
  def llstruct(t: TStr): String = {
    if (structs.contains(t)) return structs(t)
    val name = genid("%.struct")
    structs = structs + (t -> name)
    name
  }

  def apply(file: String, ls: List[LL]) {
    asm.open(file)
    structs.foreach { case (t, n) =>
        asm(n + " = type {" + t.types.map { case (a, b) => llt(b) }.mkString(", ") + "}")
    }
    asm.label("@.str = private constant [4 x i8] c\"%d\\0A\\00\"")
    asm.label("define void @print_i32(i32 %a) nounwind ssp {")
    asm.label("entry:")
    asm("call i32 (i8*, ...)* @printf(i8* getelementptr inbounds ([4 x i8]* @.str, i64 0, i64 0), i32 %a) nounwind")
    asm("ret void")
    asm.label("}")
    asm.label("define void @print_i8(i8 %a) nounwind ssp {")
    asm.label("entry:")
    asm("call i32 (i8*, ...)* @printf(i8* getelementptr inbounds ([4 x i8]* @.str, i64 0, i64 0), i8 %a) nounwind")
    asm("ret void")
    asm.label("}")

    asm.label("declare i32 @printf(i8*, ...) nounwind")

    asm.label("define i32 @main() nounwind ssp {")
    asm.label("entry:")
    ls.foreach(f)
    asm("ret i32 0")
    asm.label("}")
    asm.close()
  }
}

object genid {
  var id = 0
  def apply(s: String): String = {
    id += 1
    s + id
  }
}
object asm {
  var p: PrintWriter = null
  def open(file: String) {
    p = new PrintWriter(new BufferedWriter(new FileWriter(file)))
  }

  var indent: String = ""

  def apply(s: String, n: String = "") {
    val v = indent + s + "\t" + n + "\n"
    p.print(v)
  }
  def label(s: String) {
    asm.indent = "";
    apply(s)
    asm.indent = "\t";
  }
  def close() {
    p.close()
  }
}

object exec {
  def apply(cmd: String): (Int, String, String) = {
    val p = Runtime.getRuntime().exec(cmd)
    val stdin = (readAll(p.getInputStream()))
    val stderr = (readAll(p.getErrorStream()))
    (p.waitFor(), stdin, stderr)
  }

  def readAll(p: InputStream): String = {
    def f(s: String, i: BufferedReader): String = {
      i.readLine() match {
        case null => s
        case a => f(s + a + "\n", i)
      }
    }
    f("", new BufferedReader(new InputStreamReader(p)))
  }
}

まとめ

構造体の値を設定したり、読み込んだり出来るようになりました。

構造体の中の構造体は扱えるようになっていないと思いますが、このプログラムを拡張する事で扱えるようになります。また、配列も同じように作成する事が出来ます。

参考文献

[1]jasminのinstructionsのページ

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1