前提
こんにちは、インターン生の鍋島です。
この記事は自分のデザインパターンの理解とScalaのコーディング練習を目的としたものです。
まだ稚拙な文章ですが、暖かく見守ってもらえると嬉しいです。
もし記事中に間違いがあれば、ぜひコメントで指摘お願いします。
それでは初めていきましょう!
Compositeパターン
今回はCompositeパターンの解説をしていきます。
Compositeパターンとは構造に関するデザインパターンで、木構造などの再帰的なデータ構造を表現するのに適しています。
今回は木構造を用いで説明するので、まずは木構造について整理しておきましょう。
内部ノード: 子ノードを持つノード、葉ノード以外のノードのこと。
葉ノード: 子ノードを持たないノードのこと。
根ノード: 親ノードを持たないノードのこと。
エッジ: ノードとノードをつなぐもの。
例えば、木構造のデータに対して再帰処理をする場合を考えてみてください。
内部ノードの場合、自身が持っているノードにアクセスする処理、葉ノードは自身の値を参照する処理になり、内部ノードと葉ノードで振る舞いが違います。
なので再帰処理をする場合は、条件分岐などをして内部ノードが持っている子ノードが内部ノードのバターンと葉ノードパターンの処理を両方実装する必要が出てきます。
そこでCompositeパターンを適用することによって、内部ノードと葉ノードが同じインターフェースを持ちます。
そのため、再帰処理を実装する時に、ノードが内部ノードか葉ノードかを意識する必要がありません。
サンプルコード
今回は数式を文字列で与えると計算してくれる電卓を実装し、Compositeパターンの解説をしていきます。
この電卓は、四則演算、剰余(%)、カッコに対応しています。
例えば、"( 1 + ( 6 / 3 ) ) * ( 7 - 3 )"を与えると12と出力します。
それではコードを見ていきましょう。
サンプルコードのボリュームがすごいですが。。。
import scala.collection.mutable
/*
電卓仕様
四則演算、%、^と()に対応している
与えられる数値の範囲は正の値のみ
必ず空白区切りにする
*/
object Main extends App {
val in = "( 1 + ( 6 / 3 ) ) * ( 7 - 3 )"
// val in = "( 1 * 2 ) * ( 3 * 4 )"
// val in = "3 + 6 * ( 6 + 5 ) * ( 7 * 2 )"
// val in = "602 % 0 * 3"
// val in = "1 + 4"
val compiler = new Compiler(in)
val result = compiler.compile()
println(result)
}
/*
計算式を受け取り木構造に変換し、変換した木構造から計算結果を出す
*/
class Compiler(in: String) {
private val nodeParser = new NodeParser
val partList: List[Part] = nodeParser.parse(in)
private val stack: mutable.Stack[Node] = mutable.Stack()
private val tempStack: mutable.Stack[IterableOnce[Node]] = mutable.Stack()
private val priorityQueue: mutable.Queue[Int] = mutable.Queue()
/*
PartをNodeに変換し木構造を作成する
Partはカッコを含むため、木構造ではカッコをなくすため
*/
def compile(parts: List[Part] = partList, currentPriority: Int = 0): Int = parts match {
/*
スタックに3ノード以上ある場合は、終了せずに再帰を続ける
*/
case ::(head: Node, Nil) if stack.length > 3 =>
val partialAst = createPartAst(head, stack.pop(), stack.pop())
compile(List(partialAst), 50)
/*
木構造が完成したら、計算結果を返す
*/
case ::(head: Node, Nil) =>
val ast = createPartAst(head, stack.pop(), stack.pop())
ast.result()
/*
) が来た場合は、stack内のノードを計算しNumberにした後に、退避させていたNodeを戻す
*/
case ::(head: Brackets, tail) if head.value == ")" =>
val partialAst = createPartAst(stack.pop(), stack.pop(), stack.pop())
stack.pushAll(tempStack.pop())
val currentPriority = priorityQueue.dequeue()
compile(new Number(partialAst.result().toString) +: tail, currentPriority)
/*
( が来た場合は、stackをすべて退避させる
*/
case ::(head: Brackets, tail) if head.value == "(" =>
tempStack.push(stack.popAll())
priorityQueue += currentPriority
compile(tail, 0)
/*
数字のPartの場合は必ずstackさせる
*/
case ::(head: Number, tail) =>
stack.push(head)
compile(tail, currentPriority)
/*
現在の優先順位より高いものが来た場合はstackする
*/
case ::(head: Sign[_], tail) if currentPriority < head.priority =>
stack.push(head)
compile(tail, head.priority)
/*
初めに現在の優先順位より低いものが来た場合は、stack内の要素で部分木を作成する
最後に現在のstackの中身にNodeが一つしか存在しない場合、次の要素をstackするために現在の優先順位を下げる
*/
case ::(head: Sign[_], tail) if currentPriority >= head.priority =>
val partialAst = createPartAst(stack.pop(), stack.pop(), stack.pop())
stack.push(partialAst)
val priority = if (stack.length == 1) 0 else currentPriority
compile(head :: tail, priority)
case node => throw new Exception(s"no match $node")
}
/*
ノードと符号を受け取ることで部分木を作成する
*/
def createPartAst(left: Node, symbol: Node, right: Node): Node = symbol match {
case _: Addition[_] => new Addition[Some[Node]](Some(left), Some(right))
case _: Multiplication[_] => new Multiplication[Some[Node]](Some(left), Some(right))
case _: Subtraction[_] => new Subtraction[Some[Node]](Some(left), Some(right))
case _: Division[_] => new Division[Some[Node]](Some(left), Some(right))
case _: Remainder[_] => new Remainder[Some[Node]](Some(left), Some(right))
}
/*
デバック用
*/
def print(): Unit = {
println("====")
partList.foreach(node => println(node.value))
println("====")
}
}
/*
文字列で受け取った式を一つずつNodeに変換していく
*/
class NodeParser() {
private var numberDigit: String = ""
private var spaceFlg: Boolean = true
private var partSeq: List[Part] = List()
private var bracketsCount: Int = 0
/*
入力された文字をPartに変換をしていく
*/
def parse(in: String): List[Part] = {
val isNum = """[0-9]""".r
val isSign = """[\*/\+\-\^%]""".r
val isBrackets = """[\(\)]""".r
in.split("").foreach(s => s match {
case isNum() => patternInt(s)
case isSign() => partSeq = partSeq :+ patternSign(s)
case isBrackets() => partSeq = partSeq :+ patternBrackets(s)
case s if s == " " =>
spaceFlg = true
if (numberDigit != "") {
partSeq = partSeq :+ createNode(numberDigit)
numberDigit = ""
}
case _ =>
})
if (numberDigit != "") {
partSeq = partSeq :+ createNode(numberDigit)
numberDigit = ""
}
partSeq
}
/*
数字の場合、数字は2桁以上の場合ああるので、それを考慮している
*/
private def patternInt(s: String): Unit = partSeq.lastOption match {
case Some(_: Sign[_]) if !spaceFlg =>
throw new Exception("空白の位置が正しくありません")
case Some(_: Number) =>
throw new Exception("文法が正しくありません")
case _ =>
spaceFlg = false
numberDigit += s
}
/*
記号が入力された時の処理
*/
private def patternSign(s: String): Part = partSeq.lastOption match {
case _ if !spaceFlg =>
throw new Exception("空白の位置が正しくありません")
case Some(_: Sign[_]) =>
throw new Exception("文法が正しくありません")
case None =>
throw new Exception("文法が正しくありません")
case _ =>
spaceFlg = false
createNode(s)
}
/*
カッコが入ってきた時の処理
*/
private def patternBrackets(s: String): Part = partSeq.lastOption match {
case _ if !spaceFlg =>
throw new Exception("空白の位置が正しくありません")
case Some(_: Number) if (s == ")" && bracketsCount > 0) =>
bracketsCount -= 1
createNode(s)
case Some(_: Brackets) if (s == ")" && bracketsCount > 0) =>
bracketsCount += 1
createNode(s)
case Some(_: Sign[_]) if s == "(" =>
bracketsCount += 1
createNode(s)
case None =>
bracketsCount += 1
createNode(s)
case _ =>
throw new Exception("文法が正しくありません")
}
/*
受け取った文字からNodeを作成している
*/
private def createNode(s: String): Part = s match {
case "(" => new Brackets(s)
case ")" => new Brackets(s)
case "+" => new Addition[None.type](None, None)
case "-" => new Subtraction[None.type](None, None)
case "/" => new Division[None.type](None, None)
case "*" => new Multiplication[None.type](None, None)
case "%" => new Remainder[None.type](None, None)
case i => new Number(i)
}
/*
デバック用
*/
def print(): Unit = {
println("====")
partSeq.foreach(node => println(node.value))
println("====")
}
}
/*
数字と記号の抽象
*/
trait Node extends Part {
def result(): Int
}
trait Part {
val value: String
}
trait Sign[T <: Option[Node]] extends Node {
val left: T
val right: T
val priority: Int
}
class Addition[T <: Option[Node]](val left: T, val right: T) extends Sign[T]{
override val value: String = "+"
override val priority: Int = 1
override def result(): Int = right.fold(0)(_.result()) + left.fold(0)(_.result())
}
class Multiplication[T <: Option[Node]](val left: T, val right: T) extends Sign[T] {
override val value: String = "*"
override val priority: Int = 2
override def result():Int = right.fold(0)(_.result()) * left.fold(0)(_.result())
}
class Subtraction[T <: Option[Node]](val left: T, val right: T) extends Sign[T] {
override val value: String = "-"
override val priority: Int = 1
override def result():Int = right.fold(0)(_.result()) - left.fold(0)(_.result())
}
class Division[T <: Option[Node]](val left: T, val right: T) extends Sign[T] {
override val value: String = "/"
override val priority: Int = 2
override def result():Int = {
if (left.fold(0)(_.result()) == 0) throw new Exception("右辺に0が使用されています")
right.fold(0)(_.result()) / left.fold(0)(_.result())
}
}
class Remainder[T <: Option[Node]](val left: T, val right: T) extends Sign[T] {
override val value: String = "%"
override val priority: Int = 2
override def result():Int = {
if (left.fold(0)(_.result()) == 0) throw new Exception("右辺に0が使用されています")
right.fold(0)(_.result()) % left.fold(0)(_.result())
}
}
class Number(val value: String) extends Node {
override def result(): Int = value.toInt
}
class Brackets(val value: String) extends Part
解説
それでは解説をしていきます。
まずは電卓で計算したい値をCompilerのコンストラクタに渡しています。
渡された文字列はコンストラクタでNodeParserのparse関数に渡しています。
parse関数ではパターンマッチと正規表現を使って、数値や符号ごとに処理を切り替えてプロパティの値を書き換え、Seq[Part]を作成しています。
※"1 * ( 3 + 5 )"の時のSeq[Part]のイメージ
続いてCompilerのcompile関数で、先程作成したSeq[Part]を使用して木構造を作成し、その木構造の計算結果を返します。
ここはSeq[Part]からPartを一つずつ取得し、取り出したPartのサブクラスを見て処理を分岐しています。
Signクラスの場合は、現在の優先順位と取り出したPartの優先順位を比べて低い場合はスタックに積む、高い場合はスタックから3つNodeを取り出し部分木を作成し、部分木をスタックに積み直すというふうに処理をしています。
Bracketsクラスの場合は、現在のスタックをすべて退避させて、閉じカッコが来た時に退避させていた要素を戻し、カッコ内の結果をスタックします。
このように木構造を作っていき、最後に出来た木構造に対してresult関数を実行することで全体の計算結果が得られます。
ここからがCompositeパターンの本題に入ります。
まず今回のサンプルコードでは、NodeでCompositeパターンを使用しております。
Compositeパターンを活用することで、内部ノードと葉ノードを意識せずに再帰処理を書くことが出来ます。
これを今回のサンプルコードに当てはめると内部ノードがSign、葉ノードがNumberとなります。
そして、その共通のインターフェースであるNodeトレイトに、resultが定義されています。
すると、内部ノードを操作する場合も、葉ノードを操作する場合も、同じようにresult関数のみを操作すればよくなります。
Sign.result()では、leftとrightの値を計算し、その値を返す。
Number.result()では、自分自身の値を返す。
このように実装することで、result関数を実行する側は、子ノードが内部ノード、葉ノードでも結果がそのNode以下の合計の値が返ってきます。
そのため、内部ノードと葉ノードを意識せずに再帰処理を書くことが出来ます。
最後に
Compositeパターンのメリットは伝わりましたか?
最後まで読んでもらいありがとうございます!