
Spark: SVM Code and Annotations

I. Introduction

An SVM (Support Vector Machine) recasts classification as the problem of finding a separating hyperplane, and it classifies by maximizing the margin, i.e. the distance from the boundary points (the support vectors) to that hyperplane.
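To make "maximize the margin" concrete: the model trained by MLlib's SVMWithSGD minimizes a hinge loss with L2 regularization. A standard way to write this objective (with λ playing the role of the trainer's regularization parameter, and the 0/1 labels mapped to ±1 internally) is:

$$\min_{w}\;\frac{\lambda}{2}\lVert w\rVert^{2}+\frac{1}{n}\sum_{i=1}^{n}\max\bigl(0,\;1-y_{i}\,w^{\top}x_{i}\bigr),\qquad y_{i}\in\{-1,+1\}$$

The hinge-loss term is small when points sit on the correct side of the hyperplane with room to spare; the regularization term keeps the weight vector small, which is equivalent to keeping the margin wide.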


II. Example

1. Data

PS: only an excerpt of the data is described below; the file is named sample_svm_data.txt and can be downloaded from the machine-learning data package.
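The file contents are not reproduced in this post, but the parsing code in the next section assumes one record per line, space separated, with the 0/1 class label first and the numeric feature values after it. A made-up line, purely to illustrate the layout (not actual contents of sample_svm_data.txt):

1 0.0 2.52 0.0 0.0 1.73 0.0 4.0

Every line must have the same number of feature columns, since each record becomes a fixed-length dense feature vector.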

2. Code

package com.svm

import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.{SparkConf, SparkContext}

object TestSVMDemo {
    def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName))
        val file1 = sc.textFile("src/main/resources/svm/sample_svm_data.txt")

        // Parse the data: each line is space separated, label first, features after it
        val file = file1.map { line =>
            val strs = line.split(" ")
            val label = strs(0).toDouble
            val features = strs.tail.map(_.toDouble)
            LabeledPoint(label, Vectors.dense(features))
        }

        // Split into training (80%) and test (20%) sets, with a fixed seed for reproducibility
        val array = file.randomSplit(Array(0.8, 0.2), 5)

        // Build and train the model with SGD
        val numIterations = 120
        val model = SVMWithSGD.train(array(0), numIterations)

        // Clear the default threshold so predict() returns raw margin scores instead of
        // 0/1 labels; BinaryClassificationMetrics needs raw scores for a meaningful ROC curve
        model.clearThreshold()

        // Score the test set as (raw score, true label) pairs
        val predictions = array(1).map { test =>
            val score = model.predict(test.features)
            (score, test.label)
        }

        // Print the first 10 (score, label) pairs
        val showData = predictions.take(10)
        for (i <- showData.indices) {
            println(showData(i)._1 + "\t" + showData(i)._2)
        }

        // Evaluation metric: area under the ROC curve
        val metrics = new BinaryClassificationMetrics(predictions)
        val auROC = metrics.areaUnderROC()
        println(s"Area under ROC (AUC): $auROC")

        sc.stop()
    }
}
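The two-argument SVMWithSGD.train call above relies on the optimizer's default step size, regularization and mini-batch fraction. As a rough sketch of how you might go further, still inside the same main method and reusing sc and array(0) from above, the fuller train overload exposes those knobs, and the resulting model can be saved and reloaded (the numeric values and the output path below are illustrative placeholders, not tuned or required settings):

import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}

// Fuller training call: (training data, numIterations, stepSize, regParam, miniBatchFraction);
// the numbers here are placeholders, not tuned values
val tunedModel = SVMWithSGD.train(array(0), 120, 1.0, 0.01, 1.0)

// Persist the trained model and load it back, e.g. from a separate job
// ("target/tmp/svmModel" is a hypothetical output path)
tunedModel.save(sc, "target/tmp/svmModel")
val reloaded = SVMModel.load(sc, "target/tmp/svmModel")
println(reloaded.weights)

Saving and reloading this way lets the scoring code run in a later job without retraining the model.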
