天天看點

用java實作K-均值聚類(k-means)

首先大家了解一下什麼是K-均值聚類,如下:

K均值聚類算法是先隨機選取K個對象作為初始的聚類中心。然後計算每個對象與各個種子聚類中心之間的距離,把每個對象分配給距離它最近的聚類中心。聚類中心以及分配給它們的對象就代表一個聚類。一旦全部對象都被分配了,每個聚類的聚類中心會根據聚類中現有的對象被重新計算。這個過程將不斷重複直到滿足某個終止條件。終止條件可以是沒有(或最小數目)對象被重新分配給不同的聚類,沒有(或最小數目)聚類中心再發生變化,誤差平方和局部最小。

我們查閱資料了解到K-均值聚類的python代碼如下:

def distEclud(vecA, vecB):
    """Return the Euclidean (L2) distance between two vectors.

    Relies on the element-wise ``power``, ``sum`` and ``sqrt`` names already
    in scope; equivalent to the norm of ``vecA - vecB``.
    """
    delta = vecA - vecB
    squared = power(delta, 2)
    return sqrt(sum(squared))

def randCent(dataSet, k):
    """Build k random initial centroids inside the bounding box of dataSet.

    dataSet is indexed as dataSet[:, j]; its second dimension (n) is the
    number of features. Returns a k x n matrix in which column j holds values
    drawn as minJ + rangeJ * rand, i.e. between the column's min and max.
    NOTE(review): `random` is presumably numpy.random brought in by a
    wildcard numpy import -- confirm against the file's imports.
    """
    n = shape(dataSet)[1]
    centroids = mat(zeros((k,n)))#create centroid mat
    for j in range(n):#create random cluster centers, within bounds of each dimension
        # smallest value seen in feature column j
        minJ = min(dataSet[:,j]) 
        # width of the column's value range, forced to a plain float
        rangeJ = float(max(dataSet[:,j]) - minJ)
        # fill the whole column at once with k random values in [minJ, minJ + rangeJ)
        centroids[:,j] = mat(minJ + rangeJ * random.rand(k,1))
    return centroids
    
def kMeans(dataSet, k, distMeas=distEclud, createCent=randCent):
    """Cluster the rows of dataSet into k groups with the k-means loop.

    distMeas(a, b) supplies the distance between a centroid row and a data
    row; createCent(dataSet, k) supplies the initial centroids.

    Returns (centroids, clusterAssment). clusterAssment has one row per data
    point: column 0 is the index of the assigned centroid, column 1 is the
    squared distance to it.

    The loop repeats until a full pass over the data changes no assignment.
    Note: `print centroids` below is a Python 2 print statement.
    """
    m = shape(dataSet)[0]
    clusterAssment = mat(zeros((m,2)))#create mat to assign data points 
                                      #to a centroid, also holds SE of each point
    centroids = createCent(dataSet, k)
    clusterChanged = True
    while clusterChanged:
        clusterChanged = False
        for i in range(m):#for each data point assign it to the closest centroid
            minDist = inf; minIndex = -1
            for j in range(k):
                distJI = distMeas(centroids[j,:],dataSet[i,:])
                if distJI < minDist:
                    minDist = distJI; minIndex = j
            # any point switching centroid forces another full iteration
            if clusterAssment[i,0] != minIndex: clusterChanged = True
            # store the winning centroid index and the squared distance
            clusterAssment[i,:] = minIndex,minDist**2
        # progress output: centroids used for this pass (before recalculation)
        print centroids
        for cent in range(k):#recalculate centroids
            ptsInClust = dataSet[nonzero(clusterAssment[:,0].A==cent)[0]]#get all the point in this cluster
            centroids[cent,:] = mean(ptsInClust, axis=0) #assign centroid to mean 
    return centroids, clusterAssment
           

我們用java代碼開始實作,首先是歐式幾何距離計算

/**
 * Computes the Euclidean distance between one row of {@code vecA} and one
 * row of {@code vecB}.
 *
 * The number of compared columns is taken from {@code vecA.numCols}.
 *
 * @param vecA     matrix holding the first vector as a row
 * @param vecB     matrix holding the second vector as a row
 * @param vecA_row row index of the vector inside {@code vecA}
 * @param vecB_row row index of the vector inside {@code vecB}
 * @return square root of the summed squared column differences
 */
private static double distEclud(DenseMatrix64F vecA,DenseMatrix64F vecB,int vecA_row,int vecB_row) {

		double sumOfSquares = 0;
		int col = 0;

		while(col < vecA.numCols) {
			double delta = vecA.get(vecA_row,col) - vecB.get(vecB_row,col);
			sumOfSquares += Math.pow(delta,2);
			col++;
		}

		return Math.sqrt(sumOfSquares);

	}
           

然後是簇的初始化

/**
 * Creates {@code k} random initial centroids that lie inside the bounding
 * box of the data set.
 *
 * For every column the minimum and maximum over all rows are determined and
 * each centroid receives a uniformly random value between them.
 *
 * @param dataSet data points, one per row, one feature per column
 * @param k       number of centroids to generate
 * @return a {@code k x dataSet.numCols} matrix of random centroids
 */
private static DenseMatrix64F randCent(DenseMatrix64F dataSet,int k) {
		// Every cell is assigned in the loops below, so no explicit zeroing is needed.
		DenseMatrix64F centroids = new DenseMatrix64F(k,dataSet.numCols);

		for(int j=0;j<dataSet.numCols;j++) {
			// Double.MIN_VALUE is the smallest POSITIVE double, so it must not be
			// used as the starting point of a maximum search: a column holding only
			// negative values would report a wrong maximum. Start from the infinities.
			double minJ = Double.POSITIVE_INFINITY;
			double maxJ = Double.NEGATIVE_INFINITY;
			for(int i=0;i<dataSet.numRows;i++) {
				double value = dataSet.get(i, j);
				if(value < minJ) {
					minJ = value;
				}
				if(value > maxJ) {
					maxJ = value;
				}
			}
			double rangeJ = maxJ - minJ;

			for(int i=0;i<k;i++) {
				centroids.set(i, j, minJ + rangeJ * Math.random());
			}
		}

		return centroids;
	}
           

然後便是k-means的關鍵迭代簇函數

/**
 * Clusters the rows of {@code dataSet} into {@code k} groups with the
 * k-means algorithm.
 *
 * Each iteration assigns every point to its nearest centroid, recomputes
 * each centroid as the mean of the points assigned to it, and then orders
 * the centroids by their first coordinate. Iteration stops once a full pass
 * changes no assignment. Progress (number of changed points and the current
 * centroids) is printed to standard output on every iteration.
 *
 * @param dataSet data points, one per row, one feature per column
 * @param k       number of clusters
 * @return a two-element array: {@code [0]} the {@code k x numCols} centroid
 *         matrix, {@code [1]} a {@code numRows x 2} matrix whose column 0 is
 *         the assigned centroid index and column 1 the squared distance to it
 */
public static DenseMatrix64F[] kMeans(DenseMatrix64F dataSet,int k) {
	    DenseMatrix64F clusterAssment = new DenseMatrix64F(dataSet.numRows, 2);
	    clusterAssment.zero();
	    DenseMatrix64F centroids = randCent(dataSet,k);
	    boolean clusterChanged = true;
	    while(clusterChanged) {
	        clusterChanged = false;

	        int changed = 0;

	        // Assignment step: attach every point to its nearest centroid.
	        for(int i=0;i<dataSet.numRows;i++) {
	            double minDist = Double.MAX_VALUE;
	            int minIndex = -1;
	            for(int j=0;j<k;j++) {
	                double distJI = distEclud(centroids,dataSet,j,i);
	                if(distJI < minDist) {
	                    minDist = distJI;
	                    minIndex = j;
	                }
	            }

	            if(clusterAssment.get(i, 0) != minIndex) {
	                clusterChanged = true;
	                changed++;
	            }

	            clusterAssment.set(i, 0, minIndex);
	            clusterAssment.set(i, 1, minDist*minDist);
	        }

	        System.out.println("變動點數:"+changed);
	        System.out.println(centroids);

	        // Update step: move each centroid to the mean of ITS OWN members.
	        // The sums are taken over the rows assigned to the cluster, not over
	        // the first rows of the data set.
	        for(int cent=0;cent<k;cent++) {
	            double[] columnSums = new double[dataSet.numCols];
	            int members = 0;

	            for(int i=0;i<dataSet.numRows;i++) {
	                if(clusterAssment.get(i, 0) == cent) {
	                    members++;
	                    for(int j=0;j<dataSet.numCols;j++) {
	                        columnSums[j] += dataSet.get(i, j);
	                    }
	                }
	            }

	            // An empty cluster keeps its previous centroid (avoids division by zero).
	            if(members > 0) {
	                for(int j=0;j<dataSet.numCols;j++) {
	                    centroids.set(cent, j, columnSums[j]/members);
	                }
	            }
	        }

	        // Order the centroids by their first coordinate so the returned rows
	        // have a stable order; assignments are re-derived on the next pass.
	        for(int i=0;i<centroids.numRows-1;i++) {
	            for(int j=i+1;j<centroids.numRows;j++) {
	                if(centroids.get(i, 0) > centroids.get(j, 0)) {
	                    for(int n=0;n<centroids.numCols;n++) {
	                        double swap = centroids.get(j, n);
	                        centroids.set(j, n, centroids.get(i, n));
	                        centroids.set(i, n, swap);
	                    }
	                }
	            }
	        }

	    }

	    return new DenseMatrix64F[] {centroids, clusterAssment};
	}
           

這裡面有一個細節比較重要,就是更新簇之後,簇得排序,否則做聚類運算的時候簇點會循環換位,導致跳不出循環。

開始測試

List<String> list = new ArrayList<String>();
        // try-with-resources closes the reader even when readLine() throws,
        // so the file handle is never leaked.
        try(BufferedReader br = new BufferedReader(new FileReader("D:\\machinelearninginaction-master\\Ch10\\testSet2.txt"))){
            String s;
            while((s = br.readLine())!=null){
                // Skip blank lines (e.g. a trailing newline) that cannot be parsed as a point.
                if(!s.trim().isEmpty()) {
                    list.add(s);
                }
            }
        }catch(Exception e){
            e.printStackTrace();
        }

        // One row per data line, two feature columns (x, y).
        DenseMatrix64F dataMatIn = new DenseMatrix64F(list.size(),2);

        for(int i=0;i<list.size();i++) {
            // Columns are whitespace separated (tab in the sample file).
            String[] items = list.get(i).trim().split("\\s+");
            dataMatIn.set(i, 0, Double.parseDouble(items[0]));
            dataMatIn.set(i, 1, Double.parseDouble(items[1]));
        }

        // Cluster into 4 groups; test[0] = centroids, test[1] = per-point assignment.
        DenseMatrix64F[] test = kMeans(dataMatIn,4);

        System.out.println(test[0]);
        System.out.println(test[1]);
           

在多次循環後收斂

用java實作K-均值聚類(k-means)

ok搞定