Printing a Spark DecisionTreeModel

Software version:

   Spark: 1.6.1

Problem 1:

When building a Spark DecisionTree model (for classification), you can print the decision tree. Calling the model's toDebugString produces a string like the following:
DecisionTreeModel classifier of depth 7 with 45 nodes
  If (feature 22 <= 114.2)
   If (feature 27 <= 0.1108)
    If (feature 13 <= 45.19)
     If (feature 21 <= 32.84)
      Predict: 0.0
     Else (feature 21 > 32.84)
      If (feature 1 <= 22.61)
       If (feature 0 <= 11.49)
        Predict: 0.0
       Else (feature 0 > 11.49)
        Predict: 1.0
      Else (feature 1 > 22.61)
       Predict: 0.0
    Else (feature 13 > 45.19)
     If (feature 21 <= 22.13)
      Predict: 0.0
     Else (feature 21 > 22.13)
      If (feature 14 <= 0.004571)
       Predict: 0.0
      Else (feature 14 > 0.004571)
       Predict: 1.0
   Else (feature 27 > 0.1108)
    If (feature 21 <= 25.72)
     If (feature 24 <= 0.1786)
      If (feature 23 <= 809.7)
       Predict: 0.0
      Else (feature 23 > 809.7)
       If (feature 0 <= 14.02)
        Predict: 1.0
       Else (feature 0 > 14.02)
        Predict: 0.0
     Else (feature 24 > 0.1786)
      Predict: 1.0
    Else (feature 21 > 25.72)
     If (feature 7 <= 0.05266)
      If (feature 20 <= 15.5)
       Predict: 0.0
      Else (feature 20 > 15.5)
       If (feature 4 <= 0.09073)
        If (feature 10 <= 0.2406)
         Predict: 1.0
        Else (feature 10 > 0.2406)
         Predict: 0.0
       Else (feature 4 > 0.09073)
        Predict: 1.0
     Else (feature 7 > 0.05266)
      If (feature 12 <= 1.539)
       Predict: 0.0
      Else (feature 12 > 1.539)
       Predict: 1.0
  Else (feature 22 > 114.2)
   If (feature 27 <= 0.1397)
    If (feature 1 <= 14.96)
     Predict: 0.0
    Else (feature 1 > 14.96)
     If (feature 20 <= 18.79)
      If (feature 17 <= 0.009753)
       Predict: 1.0
      Else (feature 17 > 0.009753)
       If (feature 0 <= 17.3)
        Predict: 0.0
       Else (feature 0 > 17.3)
        Predict: 1.0
     Else (feature 20 > 18.79)
      Predict: 1.0
   Else (feature 27 > 0.1397)
    Predict: 1.0
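For context, a model like this one can be trained with the MLlib DecisionTree API roughly as follows (a minimal sketch, assuming an existing SparkContext sc; the dataset path, save path, and parameter values are placeholders, not taken from this post):

import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils

// Load a LibSVM-format dataset (hypothetical path) and hold out a test split.
val data = MLUtils.loadLibSVMFile(sc, "data/sample_libsvm_data.txt")
val Array(training, test) = data.randomSplit(Array(0.7, 0.3))

// Train a binary classifier; an empty categoricalFeaturesInfo map means
// every feature is treated as continuous.
val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "gini"
val maxDepth = 7
val maxBins = 32
val model = DecisionTree.trainClassifier(training, numClasses, categoricalFeaturesInfo,
  impurity, maxDepth, maxBins)

println(model.toDebugString)

// Optionally persist the model so it can be reloaded later with DecisionTreeModel.load.
model.save(sc, "/tmp/dtModel")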

However, the toDebugString output above is not very easy to read. Can it be printed as a proper tree structure instead?

Solution:

Use the code below. It walks the tree in reverse in-order (right subtree, then the node itself, then the left subtree), so the tree comes out sideways with the root at the left margin:
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.tree.model.Node

/**
 * Printing utilities (tree visualization helpers).
 */
object PrintUtils {

  /**
   * Print a decision tree sideways: the root sits at the left margin, the right
   * subtree (feature > threshold) is printed above it and the left subtree
   * (feature <= threshold) below it.
   * @param model the trained DecisionTreeModel
   * @return the formatted tree as a string
   */
  def printDecisionTree(model: DecisionTreeModel): String = {
    model.toString() + "\n" + printTree(model.topNode)
  }

  /** Print the whole tree (reverse in-order traversal: right subtree, node, left subtree). */
  def printTree(root: Node): String = {
    val right = root.rightNode.map(n => printTree(n, true, "")).getOrElse("")
    val rootStr = printNodeValue(root)
    val left = root.leftNode.map(n => printTree(n, false, "")).getOrElse("")
    right + rootStr + left
  }

  /** Render one node: "Feature:i > threshold" for an internal node, the prediction for a leaf. */
  def printNodeValue(root: Node): String = {
    val rootStr = root.split match {
      case Some(split) => "Feature:" + split.feature + " > " + split.threshold
      case None        => if (root.isLeaf) root.predict.toString else ""
    }
    rootStr + "\n"
  }

  /** Print a subtree with the given indentation; isRight marks a right child. */
  def printTree(root: Node, isRight: Boolean, indent: String): String = {
    val right = root.rightNode
      .map(n => printTree(n, true, indent + (if (isRight) "        " else " |      ")))
      .getOrElse("")
    val branch = if (isRight) " /" else " \\"
    val rootStr = printNodeValue(root)
    val left = root.leftNode
      .map(n => printTree(n, false, indent + (if (isRight) " |      " else "        ")))
      .getOrElse("")
    right + indent + branch + "----- " + rootStr + left
  }

}

Invocation (sc is an existing SparkContext; modelPath points to a previously saved model):
val modelPath = "..."
val model = DecisionTreeModel.load(sc, modelPath)
println(model.toDebugString)
val str = PrintUtils.printDecisionTree(model)
println(str)

This prints:
DecisionTreeModel classifier of depth 7 with 45 nodes
         /----- 1.0 (prob = 1.0)
 /----- Feature:27 > 0.1397
 |       |               /----- 1.0 (prob = 1.0)
 |       |       /----- Feature:20 > 18.79
 |       |       |       |               /----- 1.0 (prob = 1.0)
 |       |       |       |       /----- Feature:0 > 17.3
 |       |       |       |       |       \----- 0.0 (prob = 1.0)
 |       |       |       \----- Feature:17 > 0.009753
 |       |       |               \----- 1.0 (prob = 1.0)
 |       \----- Feature:1 > 14.96
 |               \----- 0.0 (prob = 1.0)
Feature:22 > 114.2
 |                               /----- 1.0 (prob = 1.0)
 |                       /----- Feature:12 > 1.539
 |                       |       \----- 0.0 (prob = 1.0)
 |               /----- Feature:7 > 0.05266
 |               |       |               /----- 1.0 (prob = 1.0)
 |               |       |       /----- Feature:4 > 0.09073
 |               |       |       |       |       /----- 0.0 (prob = 1.0)
 |               |       |       |       \----- Feature:10 > 0.2406
 |               |       |       |               \----- 1.0 (prob = 1.0)
 |               |       \----- Feature:20 > 15.5
 |               |               \----- 0.0 (prob = 1.0)
 |       /----- Feature:21 > 25.72
 |       |       |       /----- 1.0 (prob = 1.0)
 |       |       \----- Feature:24 > 0.1786
 |       |               |               /----- 0.0 (prob = 1.0)
 |       |               |       /----- Feature:0 > 14.02
 |       |               |       |       \----- 1.0 (prob = 1.0)
 |       |               \----- Feature:23 > 809.7
 |       |                       \----- 0.0 (prob = 1.0)
 \----- Feature:27 > 0.1108
         |                       /----- 1.0 (prob = 1.0)
         |               /----- Feature:14 > 0.004571
         |               |       \----- 0.0 (prob = 1.0)
         |       /----- Feature:21 > 22.13
         |       |       \----- 0.0 (prob = 1.0)
         \----- Feature:13 > 45.19
                 |               /----- 0.0 (prob = 1.0)
                 |       /----- Feature:1 > 22.61
                 |       |       |       /----- 1.0 (prob = 1.0)
                 |       |       \----- Feature:0 > 11.49
                 |       |               \----- 0.0 (prob = 1.0)
                 \----- Feature:21 > 32.84
                         \----- 0.0 (prob = 1.0)

This is much easier to read: the root sits at the left margin, each split's right branch (feature > threshold) is printed above it with a "/", and its left branch (feature <= threshold) below it with a "\".

Problem 2:

What if the nodes in a DecisionTreeModel split on categorical (discrete) features?
1. The Node defined in DecisionTreeModel can indeed hold a categorical split (see the sketch after this list);
2. When building the model, the train method takes an RDD[LabeledPoint], whose features are plain numeric vectors, so every node in the model trained here splits on a continuous feature and no categorical splits appear;
3. My guess is that a later Spark version will add the corresponding support.
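
That said, the Split stored on a Node does record a featureType and, for categorical splits, the list of categories routed to the left child. If a model with categorical splits is ever encountered, printNodeValue could be extended along the following lines (a sketch only; printNodeValueWithCategories is a hypothetical replacement for printNodeValue above, and it assumes the org.apache.spark.mllib.tree.configuration.FeatureType enum):

import org.apache.spark.mllib.tree.configuration.FeatureType
import org.apache.spark.mllib.tree.model.Node

// Hypothetical variant of printNodeValue that also labels categorical splits.
def printNodeValueWithCategories(root: Node): String = {
  val rootStr = root.split match {
    // A categorical split stores the categories that are routed to the left child.
    case Some(s) if s.featureType == FeatureType.Categorical =>
      "Feature:" + s.feature + " in {" + s.categories.mkString(",") + "}"
    // A continuous split stores a threshold: left child is <=, right child is >.
    case Some(s) =>
      "Feature:" + s.feature + " > " + s.threshold
    // No split: print the prediction for a leaf, nothing otherwise.
    case None =>
      if (root.isLeaf) root.predict.toString else ""
  }
  rootStr + "\n"
}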


Share, grow, enjoy.

Stay grounded and focused.

Please credit the original blog when reposting: http://blog.csdn.NET/fansy1990

