天天看點

Spark-MLlib執行個體——決策樹

Spark-MLlib執行個體——決策樹

通俗來說,決策樹分類的思想類似于找對象。現想象一個女孩的母親要給這個女孩介紹男朋友,于是有了下面的對話:

女兒:多大年紀了?
母親:26。
女兒:長的帥不帥?
母親:挺帥的。
女兒:收入高不?
母親:不算很高,中等情況。
女兒:是公務員不?
母親:是,在稅務局上班呢。
女兒:那好,我去見見。
           
Spark-MLlib執行個體——決策樹

以上是決策的經典例子,用spark-mllib怎麼實作訓練與預測呢

1、首先準備測試資料集

訓練資料集 Tree1

字段說明:

是否見面, 年齡  是否帥  收入(1 高 2 中等 0 少)  是否公務員

0,32 1 1 0
0,25 1 2 0
1,29 1 2 1
1,24 1 1 0
0,31 1 1 0
1,35 1 2 1
0,30 0 1 0
0,31 1 1 0
1,30 1 2 1
1,21 1 1 0
0,21 1 2 0
1,21 1 2 1
0,29 0 2 1
0,29 1 0 1
0,29 0 2 1
1,30 1 1 0
           

測試資料集 Tree2

0,32 1 2 0
1,27 1 1 1
1,29 1 1 0
1,25 1 2 1
0,23 0 2 1
           

2、Spark-MLlib決策樹應用代碼

import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.feature.HashingTF
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.{SparkConf, SparkContext}

/**
  * 決策樹分類
  */
object TreeDemo {

  def main(args: Array[String]) {

    val conf = new SparkConf().setAppName("DecisionTree").setMaster("local")
    val sc = new SparkContext(conf)
    Logger.getRootLogger.setLevel(Level.WARN)

    //訓練資料
    val data1 = sc.textFile("data/Tree1.txt")

    //測試資料
    val data2 = sc.textFile("data/Tree2.txt")


    //轉換成向量
    val tree1 = data1.map { line =>
      val parts = line.split(',')
      LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
    }

    val tree2 = data2.map { line =>
      val parts = line.split(',')
      LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
    }

    //指派
    val (trainingData, testData) = (tree1, tree2)

    //分類
    val numClasses = 2
    val categoricalFeaturesInfo = Map[Int, Int]()
    val impurity = "gini"

    //最大深度
    val maxDepth = 5
    //最大分支
    val maxBins = 32

    //模型訓練
    val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
      impurity, maxDepth, maxBins)

    //模型預測
    val labelAndPreds = testData.map { point =>
      val prediction = model.predict(point.features)
      (point.label, prediction)
    }

    //測試值與真實值對比
    val print_predict = labelAndPreds.take(15)
    println("label" + "\t" + "prediction")
    for (i <- 0 to print_predict.length - 1) {
      println(print_predict(i)._1 + "\t" + print_predict(i)._2)
    }

    //樹的錯誤率
    val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
    println("Test Error = " + testErr)
    //列印樹的判斷值
    println("Learned classification tree model:\n" + model.toDebugString)

  }

}
           

3、測試結果:

label	prediction
0.0	0.0
1.0	1.0
1.0	1.0
1.0	1.0
0.0	0.0
Test Error = 0.0
Learned classification tree model:
           

可見真實值與預測值一緻,Error為0

列印決策樹的分支值,這裡最大深度為 5 ,對應的樹結構:

Learned classification tree model:
DecisionTreeModel classifier of depth 4 with 11 nodes
  If (feature 1 <= 0.0)
   Predict: 0.0
  Else (feature 1 > 0.0)
   If (feature 3 <= 0.0)
    If (feature 0 <= 30.0)
     If (feature 2 <= 1.0)
      Predict: 1.0
     Else (feature 2 > 1.0)
      Predict: 0.0
    Else (feature 0 > 30.0)
     Predict: 0.0
   Else (feature 3 > 0.0)
    If (feature 2 <= 0.0)
     Predict: 0.0
    Else (feature 2 > 0.0)
     Predict: 1.0
           

可見預測出的分界值與真實一緻,準确率與決策樹算法,參數設定及訓練樣本的選擇覆寫有關!

繼續閱讀