天天看點

xgboost之spark上運作-scala接口概述添加依賴:RDD接口:DataFrame接口:

概述

xgboost可以在spark上運作,我用的xgboost的版本是0.7的版本,目前隻支援spark2.0以上版本上運作,

編譯好jar包,加載到maven倉庫裡面去:

  1. mvn install:install-file -Dfile=xgboost4j-spark-0.7-jar-with-dependencies.jar -DgroupId=ml.dmlc -DartifactId=xgboost4j-spark -Dversion=0.7 -Dpackaging=jar

添加依賴:

[html]  view plain  copy

  1. <dependency>  
  2.             <groupId>ml.dmlc</groupId>  
  3.             <artifactId>xgboost4j-spark</artifactId>  
  4.             <version>0.7</version>  
  5.         </dependency>  
  6.         <dependency>  
  7.             <groupId>org.apache.spark</groupId>  
  8.             <artifactId>spark-core_2.10</artifactId>  
  9.             <version>2.0.0</version>  
  10.         </dependency>  
  11.         <dependency>  
  12.             <groupId>org.apache.spark</groupId>  
  13.             <artifactId>spark-mllib_2.10</artifactId>  
  14.             <version>2.0.0</version>  
  15.         </dependency>  
  16.     </dependencies>  

RDD接口:

[python]  view plain  copy

  1. package com.meituan.spark_xgboost  
  2. import org.apache.log4j.{ Level, Logger }  
  3. import org.apache.spark.{ SparkConf, SparkContext }  
  4. import ml.dmlc.xgboost4j.scala.spark.XGBoost  
  5. import org.apache.spark.sql.{ SparkSession, Row }  
  6. import org.apache.spark.mllib.util.MLUtils  
  7. import org.apache.spark.ml.feature.LabeledPoint  
  8. import org.apache.spark.ml.linalg.Vectors  
  9. object XgboostR {  
  10.   def main(args: Array[String]): Unit = {  
  11.     Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)  
  12.     Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)  
  13.     val spark = SparkSession.builder.master("local").appName("example").  
  14.       config("spark.sql.warehouse.dir", s"file:///Users/shuubiasahi/Documents/spark-warehouse").  
  15.       config("spark.sql.shuffle.partitions", "20").getOrCreate()  
  16.     spark.conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")  
  17.       val path = "/Users/shuubiasahi/Documents/workspace/xgboost/demo/data/"  
  18.   val trainString = "agaricus.txt.train"  
  19.   val testString = "agaricus.txt.test"  
  20.     val train = MLUtils.loadLibSVMFile(spark.sparkContext, path + trainString)  
  21.     val test = MLUtils.loadLibSVMFile(spark.sparkContext, path + testString)  
  22.     val traindata = train.map { x =>  
  23.       val f = x.features.toArray  
  24.       val v = x.label  
  25.       LabeledPoint(v, Vectors.dense(f))  
  26.     }  
  27.     val testdata = test.map { x =>  
  28.       val f = x.features.toArray  
  29.       val v = x.label  
  30.        Vectors.dense(f)  
  31.     }  
  32.     val numRound = 15  
  33.      //"objective" -> "reg:linear", //定義學習任務及相應的學習目标  
  34.       //"eval_metric" -> "rmse", //校驗資料所需要的評價名額  用于做回歸  
  35.     val paramMap = List(  
  36.       "eta" -> 1f,  
  37.       "max_depth" ->5, //數的最大深度。預設值為6 ,取值範圍為:[1,∞]   
  38.       "silent" -> 1, //取0時表示列印出運作時資訊,取1時表示以緘默方式運作,不列印運作時資訊。預設值為0   
  39.       "objective" -> "binary:logistic", //定義學習任務及相應的學習目标  
  40.       "lambda"->2.5,  
  41.       "nthread" -> 1 //XGBoost運作時的線程數。預設值是目前系統可以獲得的最大線程數  
  42.       ).toMap  
  43.     println(paramMap)  
  44.     val model = XGBoost.trainWithRDD(traindata, paramMap, numRound, 55, null, null, useExternalMemory = false, Float.NaN)  
  45.     print("sucess")  
  46.     val result=model.predict(testdata)  
  47.     result.take(10).foreach(println)  
  48.     spark.stop();  
  49.   }  
  50. }  

DataFrame接口:

[python]  view plain  copy

  1. package com.meituan.spark_xgboost  
  2. import org.apache.log4j.{ Level, Logger }  
  3. import org.apache.spark.{ SparkConf, SparkContext }  
  4. import ml.dmlc.xgboost4j.scala.spark.XGBoost  
  5. import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics  
  6. import org.apache.spark.sql.{ SparkSession, Row }  
  7. object XgboostD {  
  8.   def main(args: Array[String]): Unit = {  
  9.     Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)  
  10.     Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)  
  11.     val spark = SparkSession.builder.master("local").appName("example").  
  12.       config("spark.sql.warehouse.dir", s"file:///Users/shuubiasahi/Documents/spark-warehouse").  
  13.       config("spark.sql.shuffle.partitions", "20").getOrCreate()  
  14.     spark.conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")  
  15.     val path = "/Users/shuubiasahi/Documents/workspace/xgboost/demo/data/"  
  16.     val trainString = "agaricus.txt.train"  
  17.     val testString = "agaricus.txt.test"  
  18.     val train = spark.read.format("libsvm").load(path + trainString).toDF("label", "feature")  
  19.     val test = spark.read.format("libsvm").load(path + testString).toDF("label", "feature")  
  20.     val numRound = 15  
  21.     //"objective" -> "reg:linear", //定義學習任務及相應的學習目标  
  22.     //"eval_metric" -> "rmse", //校驗資料所需要的評價名額  用于做回歸  
  23.     val paramMap = List(  
  24.       "eta" -> 1f,  
  25.       "max_depth" -> 5, //數的最大深度。預設值為6 ,取值範圍為:[1,∞]   
  26.       "silent" -> 1, //取0時表示列印出運作時資訊,取1時表示以緘默方式運作,不列印運作時資訊。預設值為0   
  27.       "objective" -> "binary:logistic", //定義學習任務及相應的學習目标  
  28.       "lambda" -> 2.5,  
  29.       "nthread" -> 1 //XGBoost運作時的線程數。預設值是目前系統可以獲得的最大線程數  
  30.       ).toMap  
  31.     val model = XGBoost.trainWithDataFrame(train, paramMap, numRound, 45, obj = null, eval = null, useExternalMemory = false, Float.NaN, "feature", "label")  
  32.     val predict = model.transform(test)  
  33.     val scoreAndLabels = predict.select(model.getPredictionCol, model.getLabelCol)  
  34.       .rdd  
  35.       .map { case Row(score: Double, label: Double) => (score, label) }  
  36.     //get the auc  
  37.     val metric = new BinaryClassificationMetrics(scoreAndLabels)  
  38.     val auc = metric.areaUnderROC()  
  39.     println("auc:" + auc)  
  40.   }  
  41. }  

繼續閱讀