
Spark MLlib decision tree algorithm

This example is taken from Chapter 4 of Advanced Analytics with Spark. The dataset comes from http://t.cn/R2wmIsI and consists of a compressed CSV data file, covtype.data.gz, together with a companion file, covtype.info, that describes the data.

Spark MLlib represents each example as a LabeledPoint, which consists of a Spark MLlib Vector holding the feature values and a target value called the label. The label is a Double, and the Vector is essentially an abstraction over many Double values, so a LabeledPoint directly supports only numeric features; with a suitable encoding, however, it can also carry categorical features. A common choice is one-hot encoding: if a weather feature takes the values sunny, cloudy, and rainy, then a sunny day is encoded as the vector 1,0,0. A small illustrative sketch of this encoding follows, and then the full example code.
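The snippet below is a minimal sketch of that weather example and is not part of the program that follows; the category names, the oneHot helper, and the label value are purely illustrative (it can be pasted into spark-shell, where the MLlib classes are already on the classpath):

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// Hypothetical categorical feature with three possible values
val categories = Seq("sunny", "cloudy", "rainy")

// One-hot encode a value: 1.0 in the position of the matching category, 0.0 elsewhere
def oneHot(value: String): Array[Double] =
  categories.map(c => if (c == value) 1.0 else 0.0).toArray

// "sunny" becomes the vector (1.0, 0.0, 0.0); the label 1.0 here is arbitrary
val example = LabeledPoint(1.0, Vectors.dense(oneHot("sunny")))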

package com.demo.rdf

import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
/**
  * Created by leslie on 16/10/26.
  */
object RunRDFs {
  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("DecisionTree")
    val sc = new SparkContext(conf)
    val rawData = sc.textFile("/zhenfei1/covtype/covtype.data")

    val data = rawData.map { line =>
      val values = line.split(",").map(_.toDouble)
      // All columns except the last one are features
      val featureVector = Vectors.dense(values.init)
      // covtype labels are 1..7; DecisionTree expects 0-based labels
      val label = values.last - 1
      LabeledPoint(label, featureVector)
    }
    // 80% training, 10% cross-validation, 10% test
    val Array(trainData, cvData, testData) = data.randomSplit(Array(0.8, 0.1, 0.1))
    trainData.cache(); cvData.cache(); testData.cache()
    simpleDecisionTree(trainData,cvData)
    randomClassifier(trainData,cvData)
    evaluate(trainData,cvData,testData)
  }

  /**
    * Train a simple decision tree classifier and report its accuracy on the CV set.
    *
    * @param trainData
    * @param cvData
    */
  def simpleDecisionTree(trainData: RDD[LabeledPoint], cvData: RDD[LabeledPoint]) = {
    // 7 cover types, no categorical features, entropy impurity, maxDepth 4, maxBins 100
    // (depth and bins follow the book's starting values)
    val model = DecisionTree.trainClassifier(trainData, 7, Map[Int, Int](), "entropy", 4, 100)
    val metrics = getMetrics(model, cvData)
    println("confusionMatrix: " + metrics.confusionMatrix)
    println("precision: " + metrics.precision)
    // Per-class precision and recall for the 7 cover types
    (0 until 7).map(category =>
      (metrics.precision(category), metrics.recall(category))
    ).foreach(println)
  }

  /**
    * Baseline: accuracy of a classifier that guesses a class at random
    * according to the class prior probabilities.
    *
    * @param trainData
    * @param cvData
    */
  def randomClassifier(trainData: RDD[LabeledPoint], cvData: RDD[LabeledPoint]) = {
    val trainPriorProbabilities = classProbabilities(trainData)
    val cvPriorProbabilities = classProbabilities(cvData)
    // Probability that a random guess drawn from the training prior matches a CV label:
    // the sum over classes of P_train(class) * P_cv(class)
    val accuracy = trainPriorProbabilities.zip(cvPriorProbabilities).map {
      case (trainProb, cvProb) => trainProb * cvProb
    }.sum
    println(accuracy)
  }
  /**
    * Gradient-boosted trees classifier, evaluated over a small grid of
    * tree depths and bin counts (the grid values below are illustrative).
    *
    * @param trainData
    * @param cvData
    */
  def GBDTClassifier(trainData: RDD[LabeledPoint], cvData: RDD[LabeledPoint]) = {
    val evaluations =
      for (depth <- Array(5, 10); bins <- Array(40, 300)) yield {
        val boostingStrategy = BoostingStrategy.defaultParams("Classification")
        boostingStrategy.setNumIterations(10)            // number of boosting iterations (illustrative)
        boostingStrategy.treeStrategy.setNumClasses(7)   // 7 cover types
        boostingStrategy.treeStrategy.setMaxDepth(depth) // maximum depth of each tree
        boostingStrategy.treeStrategy.setMaxBins(bins)   // number of bins for continuous features
        boostingStrategy.treeStrategy.setCategoricalFeaturesInfo(Map[Int, Int]())

        val model = GradientBoostedTrees.train(trainData, boostingStrategy)
        val trainAccuracy = getMetrics(model, trainData).precision
        val cvAccuracy = getMetrics(model, cvData).precision
        ((depth, bins), (trainAccuracy, cvAccuracy))
      }
    evaluations.sortBy(_._2._2).reverse.foreach(println)
  }
  /**
    * Decision tree tuning: try combinations of impurity, depth, and bins,
    * and compare accuracy on the CV set.
    *
    * @param trainData
    * @param cvData
    * @param testData
    */
  def evaluate(trainData: RDD[LabeledPoint], cvData: RDD[LabeledPoint], testData: RDD[LabeledPoint]) = {
    // Hyperparameter grid as in the book's chapter 4
    val evaluations =
      for (impurity <- Array("gini", "entropy"); depth <- Array(1, 20); bins <- Array(10, 300)) yield {
        val model = DecisionTree.trainClassifier(trainData, 7, Map[Int, Int](), impurity, depth, bins)
        val accuracy = getMetrics(model, cvData).precision
        ((impurity, depth, bins), accuracy)
      }
    evaluations.sortBy(_._2).reverse.foreach(println)
    // Retrain on train + CV with the best parameters, then check against the test set
    val model = DecisionTree.trainClassifier(trainData.union(cvData), 7, Map[Int, Int](), "entropy", 20, 300)
    println(getMetrics(model, testData).precision)
    println(getMetrics(model, trainData.union(cvData)).precision)
  }

  /**
    * The categorical features in the data are one-hot encoded, which forces the
    * decision tree algorithm to consider each value of a categorical feature
    * separately, increasing memory usage and slowing training. Undo the one-hot
    * encoding: columns 10-13 are the 4 wilderness-area indicators and columns
    * 14-53 are the 40 soil-type indicators; each group is collapsed back into a
    * single categorical feature.
    *
    * @param rawdata
    * @return
    */
  def unencodeOneHot(rawdata: RDD[String]): RDD[LabeledPoint] = {
    rawdata.map { line =>
      val values = line.split(',').map(_.toDouble)
      // Which of the 4 wilderness indicator columns is 1 gives the category index
      val wilderness = values.slice(10, 14).indexOf(1.0).toDouble
      // Which of the 40 soil-type indicator columns is 1 gives the category index
      val soil = values.slice(14, 54).indexOf(1.0).toDouble
      // Keep the 10 numeric features and append the two decoded categorical features
      val featureVector = Vectors.dense(values.slice(0, 10) :+ wilderness :+ soil)
      val label = values.last - 1
      LabeledPoint(label, featureVector)
    }
  }

  /**
    * Repeat the hyperparameter search using the decoded categorical features.
    *
    * @param rawdata
    */
  def evaluateCategorical(rawdata: RDD[String]) = {
    val data = unencodeOneHot(rawdata)
    val Array(trainData, cvData, testData) = data.randomSplit(Array(0.8, 0.1, 0.1))
    trainData.cache(); cvData.cache(); testData.cache()
    // Feature 10 (wilderness) has 4 categories and feature 11 (soil type) has 40
    val evaluations = for (impurity <- Array("gini", "entropy"); depth <- Array(10, 20, 30); bins <- Array(40, 300)) yield {
      val model = DecisionTree.trainClassifier(trainData, 7, Map(10 -> 4, 11 -> 40), impurity, depth, bins)
      val trainAccuracy = getMetrics(model, trainData).precision
      val cvAccuracy = getMetrics(model, cvData).precision
      ((impurity, depth, bins), (trainAccuracy, cvAccuracy))
    }
    evaluations.sortBy(_._2._2).reverse.foreach(println)
//    val model = DecisionTree.trainClassifier(
//      trainData.union(cvData), 7, Map(10 -> 4, 11 -> 40), "entropy", 30, 300)
//    println(getMetrics(model, testData).precision)
//
//    trainData.unpersist()
//    cvData.unpersist()
//    testData.unpersist()
  }
  def testCategorical(rawData: RDD[String]) = {
    val data = unencodeOneHot(rawData)
    val Array(trainData, cvData, testData) = data.randomSplit(Array(0.8, 0.1, 0.1))
    trainData.cache(); cvData.cache(); testData.cache()
    // Final model on train + CV; depth 30 and 300 bins follow the tuned values above
    val model = DecisionTree.trainClassifier(trainData.union(cvData), 7, Map[Int, Int](), "entropy", 30, 300)
    println(getMetrics(model, testData).precision)
  }

  /**
    * Train a random forest on the decoded categorical features, evaluate it on the
    * CV set, and predict the cover type of a single hand-written sample.
    *
    * @param rawData
    */
  def evaluateForest(rawData: RDD[String]) = {
    val data = unencodeOneHot(rawData)
    val Array(trainData, cvData, testData) = data.randomSplit(Array(0.8, 0.1, 0.1))
    trainData.cache(); cvData.cache(); testData.cache()
    // 20 trees, "auto" feature-subset strategy, entropy impurity, maxDepth 30, maxBins 300
    val forest = RandomForest.trainClassifier(
      trainData, 7, Map(10 -> 4, 11 -> 40), 20, "auto", "entropy", 30, 300)
    val predictionsAndLabels = cvData.map(example =>
      (forest.predict(example.features), example.label)
    )
    println(new MulticlassMetrics(predictionsAndLabels).precision)
    // A single example: 10 numeric features plus the 2 decoded categorical features
    val input = "2709,125,28,67,23,3224,253,207,61,6094,0,29"
    val vector = Vectors.dense(input.split(",").map(_.toDouble))
    println(forest.predict(vector))
  }
  /**
    * Fraction of examples in each class, ordered by class label.
    */
  def classProbabilities(data: RDD[LabeledPoint]): Array[Double] = {
    val countByCategory = data.map(_.label).countByValue()
    val counts = countByCategory.toArray.sortBy(_._1).map(_._2)
    counts.map(_.toDouble / counts.sum)
  }

  /**
    * Binary-classification metrics for a gradient-boosted trees model
    * (model evaluation for the personalized push use case).
    *
    * @param model
    * @param data
    * @return
    */
  def getMetric(model: GradientBoostedTreesModel, data: RDD[LabeledPoint]): BinaryClassificationMetrics = {
    val predictionsAndLabels = data.map { line =>
      (model.predict(line.features), line.label)
    }
    new BinaryClassificationMetrics(predictionsAndLabels)
  }
  def getMetrics(model: GradientBoostedTreesModel, data: RDD[LabeledPoint]): MulticlassMetrics = {
    val predictionsAndLabels = data.map { line =>
      (model.predict(line.features), line.label)
    }
    new MulticlassMetrics(predictionsAndLabels)
  }
  def getMetrics(model: DecisionTreeModel, data: RDD[LabeledPoint]): MulticlassMetrics = {
    val predictionsAndLabels = data.map { line =>
      (model.predict(line.features), line.label)
    }
    new MulticlassMetrics(predictionsAndLabels)
  }

}
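Note that main only runs simpleDecisionTree, randomClassifier, and evaluate. The remaining methods take the raw CSV lines, since unencodeOneHot re-parses each line itself. A minimal sketch of wiring them in, using the same data path as main (this call sequence is not part of the original program):

    val rawData = sc.textFile("/zhenfei1/covtype/covtype.data")
    evaluateCategorical(rawData)  // grid search with decoded categorical features
    evaluateForest(rawData)       // random forest on the decoded features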