Spark-MLlib執行個體——決策樹
通俗來說,決策樹分類的思想類似于找對象。現想象一個女孩的母親要給這個女孩介紹男朋友,于是有了下面的對話:
女兒:多大年紀了?
母親:26。
女兒:長的帥不帥?
母親:挺帥的。
女兒:收入高不?
母親:不算很高,中等情況。
女兒:是公務員不?
母親:是,在稅務局上班呢。
女兒:那好,我去見見。
![](https://img.laitimes.com/img/_0nNw4CM6IyYiwiM6ICdiwiIyVGduV2QvwVe0lmdhJ3ZvwFM38CXlZHbvN3cpR2Lc1TPB10QGtWUCpEMJ9CXsxWam9CXwADNvwVZ6l2c052bm9CXUJDT1wkNhVzLcRnbvZ2LcZXUYpVd1kmYr50MZV3YyI2cKJDT29GRjBjUIF2LcRHelR3LcJzLctmch1mclRXY39DM0MjNwIDMzIjNycDM2EDMy8CX0Vmbu4GZzNmLn9Gbi1yZtl2Lc9CX6MHc0RHaiojIsJye.jpg)
以上是決策的經典例子,用spark-mllib怎麼實作訓練與預測呢
1、首先準備測試資料集
訓練資料集 Tree1
字段說明:
是否見面, 年齡 是否帥 收入(1 高 2 中等 0 少) 是否公務員
0,32 1 1 0
0,25 1 2 0
1,29 1 2 1
1,24 1 1 0
0,31 1 1 0
1,35 1 2 1
0,30 0 1 0
0,31 1 1 0
1,30 1 2 1
1,21 1 1 0
0,21 1 2 0
1,21 1 2 1
0,29 0 2 1
0,29 1 0 1
0,29 0 2 1
1,30 1 1 0
測試資料集 Tree2
0,32 1 2 0
1,27 1 1 1
1,29 1 1 0
1,25 1 2 1
0,23 0 2 1
2、Spark-MLlib決策樹應用代碼
import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.feature.HashingTF
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.{SparkConf, SparkContext}
/**
* 決策樹分類
*/
object TreeDemo {
def main(args: Array[String]) {
val conf = new SparkConf().setAppName("DecisionTree").setMaster("local")
val sc = new SparkContext(conf)
Logger.getRootLogger.setLevel(Level.WARN)
//訓練資料
val data1 = sc.textFile("data/Tree1.txt")
//測試資料
val data2 = sc.textFile("data/Tree2.txt")
//轉換成向量
val tree1 = data1.map { line =>
val parts = line.split(',')
LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
}
val tree2 = data2.map { line =>
val parts = line.split(',')
LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
}
//指派
val (trainingData, testData) = (tree1, tree2)
//分類
val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "gini"
//最大深度
val maxDepth = 5
//最大分支
val maxBins = 32
//模型訓練
val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
impurity, maxDepth, maxBins)
//模型預測
val labelAndPreds = testData.map { point =>
val prediction = model.predict(point.features)
(point.label, prediction)
}
//測試值與真實值對比
val print_predict = labelAndPreds.take(15)
println("label" + "\t" + "prediction")
for (i <- 0 to print_predict.length - 1) {
println(print_predict(i)._1 + "\t" + print_predict(i)._2)
}
//樹的錯誤率
val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
println("Test Error = " + testErr)
//列印樹的判斷值
println("Learned classification tree model:\n" + model.toDebugString)
}
}
3、測試結果:
label prediction
0.0 0.0
1.0 1.0
1.0 1.0
1.0 1.0
0.0 0.0
Test Error = 0.0
Learned classification tree model:
可見真實值與預測值一緻,Error為0
列印決策樹的分支值,這裡最大深度為 5 ,對應的樹結構:
Learned classification tree model:
DecisionTreeModel classifier of depth 4 with 11 nodes
If (feature 1 <= 0.0)
Predict: 0.0
Else (feature 1 > 0.0)
If (feature 3 <= 0.0)
If (feature 0 <= 30.0)
If (feature 2 <= 1.0)
Predict: 1.0
Else (feature 2 > 1.0)
Predict: 0.0
Else (feature 0 > 30.0)
Predict: 0.0
Else (feature 3 > 0.0)
If (feature 2 <= 0.0)
Predict: 0.0
Else (feature 2 > 0.0)
Predict: 1.0
可見預測出的分界值與真實一緻,準确率與決策樹算法,參數設定及訓練樣本的選擇覆寫有關!