在完成sparkMllib GMM算法例子之前需要知道幾個概念。1、高斯分布、2、多元高斯分布。3、高斯混合分布。4、協方差
GMM稱為混合高斯分布,它在單高斯分布(又稱正太分布,一維正太分布)的基礎上針對多元變量發展出來的。(以下參考了百度詞條内容)
1)單高斯分布公式:

,該公式的推導以及意義大家可以自行百度,這裡隻講一下各個參數在公式中的意義:
μ是正态分布的位置參數,描述正态分布的集中趨勢位置。機率規律為取與μ鄰近的值的機率大,而取離μ越遠的值的機率越小。正态分布以X=μ為對稱軸,左右完全對稱。正态分布的期望、均數、中位數、衆數相同,均等于μ。
σ描述正态分布資料資料分布的離散程度,σ越大,資料分布越分散,σ越小,資料分布越集中。也稱為是正态分布的形狀參數,σ越大,曲線越扁平,反之,σ越小,曲線越瘦高。
2)多元單高斯分布公式:
由上面的定義可知,多元單高斯分布的方差其實是協方差矩陣。
3)高斯混合分布:就是多個高斯分布(可能是單高斯也可能是多元高斯)的組合。下面是李航老師在《統計學習方法》
由上圖公式可知,高斯混合分布多了一個參數
,該參數就是每單高斯分布在高斯混合分布裡面的權重。
4)協方差矩陣的含義可以參考該篇博文:http://blog.csdn.net/yangdashi888/article/details/52397990
sparkMllib GMM算法就是根據一批給定的随機變量,每個随機變量肯能是一維的,也可能是多元的,然後求出高斯混合分布中的三個參數:
1、a權重。2、μ(如果是多元就是一個數組)3、方差(一維)/協方差矩陣(多元)
以下是sparkMllib GMM的例子。
1、資料gmm_data.txt中是二維資料,部分資料展示如下:
2.59470454e+00 2.12298217e+00
1.15807024e+00 -1.46498723e-01
2.46206638e+00 6.19556894e-01
-5.54845070e-01 -7.24700066e-01
-3.23111426e+00 -1.42579084e+00
2、資料gmm_data1.txt中是二維資料,部分資料展示如下:
2.59470454e+00
2.12298217e+00
1.15807024e+00
-1.46498723e-01
案例代碼如下:
package spark;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
// $example on$
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.clustering.GaussianMixture;
import org.apache.spark.mllib.clustering.GaussianMixtureModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
public class JavaGaussianMixtureExample {
public static void main(String[] args) {
Logger logger = Logger.getLogger(JavaGaussianMixtureExample.class);
Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF);
SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("JavaGaussianMixtureExample");
JavaSparkContext jsc = new JavaSparkContext(conf);
String path = "F:/spark-2.1.0-bin-hadoop2.6/data/mllib/gmm_data1.txt";
JavaRDD<String> data = jsc.textFile(path);
JavaRDD<Vector> parsedData = data.map(f->{
return Vectors.dense(Double.parseDouble(f.trim()));
});
parsedData.cache();
/**
* k指定了高斯混合分布中的高斯分布個數。
*/
GaussianMixtureModel gmm = new GaussianMixture().setK(2).run(parsedData.rdd());
for (int j = 0; j < gmm.k(); j++) {
System.out.printf("一維混合高斯分布得到的資料如下:\nweight=%f\nmu=%s\nsigma=\n%s\n", gmm.weights()[j], gmm.gaussians()[j].mu(),
gmm.gaussians()[j].sigma());
}
logger.info("split line =====================================");
String path2 = "F:/spark-2.1.0-bin-hadoop2.6/data/mllib/gmm_data.txt";
JavaRDD<String> data2 = jsc.textFile(path2);
JavaRDD<Vector> parsedData2 = data2.map(s -> {
String[] sarray = s.trim().split(" ");
double[] values = new double[sarray.length];
for (int i = 0; i < sarray.length; i++) {
values[i] = Double.parseDouble(sarray[i]);
}
return Vectors.dense(values);
});
parsedData2.cache();
GaussianMixtureModel gmm2 = new GaussianMixture().setK(2).run(parsedData2.rdd());
for (int j = 0; j < gmm2.k(); j++) {
System.out.printf("二維混合高斯分布得到的資料如下:\nweight=%f\nmu=%s\nsigma=\n%s\n", gmm2.weights()[j], gmm2.gaussians()[j].mu(),
gmm2.gaussians()[j].sigma());
}
jsc.stop();
}
}
執行結果如下: