天天看點

【Java】K-means算法Java實作以及圖像切割

1.K-means算法簡述以及代碼原型

資料挖掘中一個重要算法是K-means。我這裡就不做具體介紹。假設感興趣的話能夠移步陳皓的部落格:   

 http://www.csdn.net/article/2012-07-03/2807073-k-means 講得非常好

    總的來講,k-means聚類須要下面幾個步驟:

         ①.初始化資料

         ②.計算初始的中心點,能夠随機選擇

         ③.計算每一個點到每一個聚類中心的距離。而且劃分到距離最短的聚類中心簇中

         ④.計算每一個聚類簇的平均值,這個均值作為新的聚類中心,反複步驟3

         ⑤.假設達到最大循環或者是聚類中心不再變化或者聚類中心變化幅度小于一定範圍時,停止循環。

    恩。原理就是這樣,超級簡單。可是Java算法實作起來代碼量并不小。這個代碼也不算是全然自己寫的啦。也有些借鑒。我把k-means實作封裝在了一個類裡面,這樣就能夠随時調用了呢。

import java.util.ArrayList;
import java.util.Random;

public class kmeans {
	private int k;//簇數
	private int m;//疊代次數
	private int dataSetLength;//資料集長度
	private ArrayList<double[]> dataSet;//資料集合
	private ArrayList<double[]> center;//中心連結清單
	private ArrayList<ArrayList<double[]>> cluster;//簇
	private ArrayList<Float> jc;//誤差平方和,這個是用來計算中心聚點的移動哦
	private Random random;
	
	//設定原始資料集合
	public void setDataSet(ArrayList<double[]> dataSet){
		this.dataSet=dataSet;
	}
	//獲得簇分組
	public  ArrayList<ArrayList<double[]>> getCluster(){
		return this.cluster;
	}
	//構造函數,傳入要分的簇的數量
	public kmeans(int k){
		if(k<=0)
			k=1;
		this.k=k;
	}
	//初始化
	private void init(){
		m=0;
		random=new Random();
		if(dataSet==null||dataSet.size()==0)
			initDataSet();
		dataSetLength=dataSet.size();
		if(k>dataSetLength)
			k=dataSetLength;
		center=initCenters();
		cluster=initCluster();
		jc=new ArrayList<Float>();
	}
	//初始化資料集合
	private void initDataSet(){
		dataSet=new ArrayList<double[]>();
		double[][] dataSetArray=new double[][]{{8,2},{3,4},{2,5},{4,2},
				{7,3},{6,2},{4,7},{6,3},{5,3},{6,3},{6,9},
				{1,6},{3,9},{4,1},{8,6}};
		for(int i=0;i<dataSetArray.length;i++)
			dataSet.add(dataSetArray[i]);
	}
	//初始化中心連結清單,分成幾簇就有幾個中心
	private ArrayList<double[]> initCenters(){
		ArrayList<double[]> center= new ArrayList<double[]>();
		//生成一個随機數列。
		int[] randoms=new int[k];
		boolean flag;
		int temp=random.nextInt(dataSetLength);
		randoms[0]=temp;
		for(int i=1;i<k;i++){
			flag=true;
			while(flag){
				temp=random.nextInt(dataSetLength);
				int j=0;
				while(j<i){
					if(temp==randoms[j])
						break;
					j++;
				}
				if(j==i)
					flag=false;
			}
			randoms[i]=temp;
		}
		for(int i=0;i<k;i++)
			center.add(dataSet.get(randoms[i]));
		return center;
	}
	//初始化簇集合
	private ArrayList<ArrayList<double[]>> initCluster(){
		ArrayList<ArrayList<double[]>> cluster=
				new ArrayList<ArrayList<double[]>>();
		for(int i=0;i<k;i++)
			cluster.add(new ArrayList<double[]>());
		return cluster;
	}
	//計算距離
	private double distance(double[] element,double[] center){
		double distance=0.0f;
		double x=element[0]-center[0];
		double y=element[1]-center[1];
		double z=element[2]-center[2];
		double sum=x*x+y*y+z*z;
		distance=(double)Math.sqrt(sum);
		return distance;
	}
	//計算最短的距離
	private int minDistance(double[] distance){
		double minDistance=distance[0];
		int minLocation=0;
		for(int i=0;i<distance.length;i++){
			if(distance[i]<minDistance){
				minDistance=distance[i];
				minLocation=i;
			}else if(distance[i]==minDistance){
				if(random.nextInt(10)<5){
					minLocation=i;
				}
			}
		}
		return minLocation;
	}
	//每一個點分類
	private void clusterSet(){
		double[] distance=new double[k];
		for(int i=0;i<dataSetLength;i++){
			//計算到每一個中心店的距離
			for(int j=0;j<k;j++)
				distance[j]=distance(dataSet.get(i),center.get(j));
			//計算最短的距離
			int minLocation=minDistance(distance);
			//把他加到聚類裡
			cluster.get(minLocation).add(dataSet.get(i));
		}
	}
	//計算新的中心
	private void setNewCenter(){
		for(int i=0;i<k;i++){
			int n=cluster.get(i).size();
			if(n!=0){
				double[] newcenter={0,0};
				for(int j=0;j<n;j++){
					newcenter[0]+=cluster.get(i).get(j)[0];
					newcenter[1]+=cluster.get(i).get(j)[1];
				}
				newcenter[0]=newcenter[0]/n;
				newcenter[1]=newcenter[1]/n;
				center.set(i, newcenter);
			}
		}
	}
	//求2點的誤差平方
	private double errosquare(double[] element,double[] center){
		double x=element[0]-center[0];
		double y=element[1]-center[1];
		double errosquare=x*x+y*y;
		return errosquare;
	}
	//計算誤差平方和準則函數
	private void countRule(){
		float jcf=0;
		for(int i=0;i<cluster.size();i++){
			for(int j=0;j<cluster.get(i).size();j++)
				jcf+=errosquare(cluster.get(i).get(j),center.get(i));
		jc.add(jcf);
		}
	}
	//核心算法
	private void Kmeans(){
		//初始化各種變量,随機標明中心。初始化聚類
		init();
		//開始循環
		while(true){
			//把每一個點分到聚類中去
			clusterSet();
			//計算目标函數
			countRule();
			//檢查誤差變化。由于我規定的計算循環次數為50次,是以就不用計算這個啦。你要願意用也能夠,就是慢一點
			/*
			if(m!=0){
				if(jc.get(m)-jc.get(m-1)==0)
					break;
			}*/
			if(m>=50)
				break;
			//否則繼續生成新的中心
			setNewCenter();
			m++;
			cluster.clear();
			cluster=initCluster();

		}
	}
           
//僅僅暴露一個接口給外部類
	public void execute(){
		System.out.print("start kmeans\n");
		Kmeans();
		System.out.print("kmeans end\n");
	}
           
//用來在外面列印出來已經分好的聚類
	public void printDataArray(ArrayList<double[]> data,String dataArrayName){
		for(int i=0;i<data.size();i++){
			System.out.print("print:"+dataArrayName+"["+i+"]={"+data.get(i)[0]+","+data.get(i)[1]+"}\n");
		}
		System.out.print("==========================");
	}
}
           

  嗯。代碼就是這樣。凝視寫的非常具體,也都能看得懂。

以下我給一個測試樣例。

import java.util.ArrayList;

public class Test {
	public static void main(String[] args){
		kmeans k=new kmeans(2);
		ArrayList<double[]> dataSet=new ArrayList<double[]>();
		dataSet.add(new double[]{2,2,2});
		dataSet.add(new double[]{1,2,2});
		dataSet.add(new double[]{2,1,2});
		dataSet.add(new double[]{1,3,2});
		dataSet.add(new double[]{3,1,2});
		dataSet.add(new double[]{-2,-2,-2});
		dataSet.add(new double[]{-1,-2,-2});
		dataSet.add(new double[]{-2,-1,-2});
		dataSet.add(new double[]{-3,-1,-2});
		dataSet.add(new double[]{-1,-3,-2});


		k.setDataSet(dataSet);
		k.execute();
		ArrayList<ArrayList<double[]>> cluster=k.getCluster();
		for(int i=0;i<cluster.size();i++){
			k.printDataArray(cluster.get(i), "cluster["+i+"]");
		}
	}
}
           

   沒啥難度,也就是輸入寫初始資料。然後運作k-means在進行分類。最後列印一下。

這個原型代碼非常粗糙。沒有加入聚類個數以及循環次數的變量。這些須要自己動手啦。

2.k-means應用圖像切割

  我們能夠把k-means聚類放在圖像切割上,也就是說把一個顔色的像素分為一類,然後再塗一個顔色。

像這樣。

【Java】K-means算法Java實作以及圖像切割
【Java】K-means算法Java實作以及圖像切割

左邊就是聚類之前的,右邊是聚類之後的 ,看起來還是滿炫酷的。事實上聚類算法也是非常easy擴充到這裡的。 有以下四個提示(由于是作業,我決定先不放馬,不然到時候作業雷同我的學分就咖喱gaygay了):    ①.上面的原型代碼是對二維的資料進行分類,那我們也知道。一個顔色有RGB三種原色構成,也就是說我們僅僅須要 在二維的基礎上。加上一維資料就吼啦。非常easy有木有,改變下數組結構,在距離計算程式設計三維歐式距離就吼。    ②.Java有自帶的圖像處理類,是以讀取資料敲擊友善。我給一點代碼提示哦

//讀取指定檔案夾的圖檔資料,而且寫入數組,這個資料要繼續處理
	private int[][] getImageData(String path){
		BufferedImage bi=null;
		try{
			bi=ImageIO.read(new File(path));
		}catch (IOException e){
			e.printStackTrace();
		}
		int width=bi.getWidth();
		int height=bi.getHeight();
		int [][] data=new int[width][height];
		for(int i=0;i<width;i++)
			for(int j=0;j<height;j++)
				data[i][j]=bi.getRGB(i, j);
		/*測試輸出
		for(int i=0;i<data.length;i++)
			for(int j=0;j<data[0].length;j++)
				System.out.println(data[i][j]);*/
		return data;
	}
	//用來處理擷取的像素資料,提取我們須要的寫入dataItem數組
	private dataItem[][] InitData(int [][] data){
		dataItem[][] dataitems=new dataItem[data.length][data[0].length];
		for(int i=0;i<data.length;i++){
			for(int j=0;j<data[0].length;j++){
				dataItem di=new dataItem();
				Color c=new Color(data[i][j]);
				di.r=(double)c.getRed();
				di.g=(double)c.getGreen();
				di.b=(double)c.getBlue();
				di.group=1;
				dataitems[i][j]=di;
			}
		}
		return dataitems;
	}
           
//介貨是用來輸出圖像的
<pre name="code" class="java">           private void ImagedataOut(String path){
		Color c0=new Color(255,0,0);
		Color c1=new Color(0,255,0);
		Color c2=new Color(0,0,255);
		Color c3=new Color(128,128,128);
		BufferedImage nbi=new BufferedImage(source.length,source[0].length,BufferedImage.TYPE_INT_RGB);
		for(int i=0;i<source.length;i++){
			for(int j=0;j<source[0].length;j++){
				if(source[i][j].group==0)
					nbi.setRGB(i, j, c0.getRGB());
				else if(source[i][j].group==1)
					nbi.setRGB(i, j, c1.getRGB());
				else if(source[i][j].group==2)
					nbi.setRGB(i, j, c2.getRGB());
				else if (source[i][j].group==3)
					nbi.setRGB(i, j, c3.getRGB());
				//Color c=new Color((int)center[source[i][j].group].r,
				//		(int)center[source[i][j].group].g,(int)center[source[i][j].group].b);
				//nbi.setRGB(i, j, c.getRGB());
			}
		}
		try{
			ImageIO.write(nbi, "jpg", new File(path));
		}catch(IOException e){
			e.printStackTrace();
			}
	}
           

    非常舒爽。你問我dataItem是啥?等我交完作業我就告訴你。     ③.有一點不同的是。注意資料格式。胖胖開始用的就是int類型,結果在計算新的聚類中心的時候溢出了呢。。

。所幸鵬鵬改成了double。可是鵬鵬在計算距離的時候又寫錯了,最後還是機智的胖胖鵬解決掉了全部的bug。

    ④.注意讀取圖檔的時候保護好資料的順序,也就是用一個二維數組來存儲,這樣在寫的時候就不用記錄像素點的位置,輸出的時候也非常友善。    就是這些。。。

等我作業交完就來一次完整的代碼解說。