決定木はそのロジック的にモデルの解釈がしやすくて、意外と現実世界で利用されることが多いような気がします。
そんなわけで、Spark MLlib (ML Pipeline) の 決定木 のモデルについて、その構造を標準出力する機能を作ったのでメモとして残して置きます。
package biz.k11i.spark.misc
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.tree.{CategoricalSplit, ContinuousSplit, InternalNode, LeafNode}
/**
* spark.ml の DecisionTreeClassificationModel の木構造を標準出力に書き出す。
*/
object DecisionTreePrinter {
def printTree(model: DecisionTreeClassificationModel): Unit = {
model.rootNode match {
case node: InternalNode => printNodes(node, 0)
case leaf: LeafNode => printLeaf(leaf, 0)
}
}
def printNodes(node: InternalNode, numIndents: Int): Unit = {
val indents = " " * numIndents
node.split match {
case cat: CategoricalSplit => println(s"${indents}category, featureIndex=${cat.featureIndex}, left=${cat.leftCategories.mkString(",")}, right=${cat.rightCategories.mkString(",")}")
case con: ContinuousSplit => println(s"${indents}continuous, featureIndex=${con.featureIndex}, threshold=${con.threshold}")
}
Seq(node.leftChild, node.rightChild).foreach {
case internalNode: InternalNode => printNodes(internalNode, numIndents + 1)
case leafNode: LeafNode => printLeaf(leafNode, numIndents + 1)
}
}
def printLeaf(node: LeafNode, indent: Int): Unit = {
val indents = " " * indent
println(s"${indents}prediction=${node.prediction}")
}
}
これにちょこっと手を加えれば、Random forest のモデルも Gradient-boosted trees のモデルも同様にダンプできると思うけど、面倒くさいのでまた今度。