天天看點

weka knn算法改進與實作

本文在 Weka 下實作 KNN 算法,主要使用高斯函數加權,並選取最優 K 值進行優化。你也可以參考網上文檔,將下文的 KNN_lsh.java 複製到某一目錄並進行相關設定,進而在 Weka GUI 中測試改進。

檔案目錄:

weka knn算法改進與實作
package cug.lsh;

import weka.classifiers.*;
import weka.core.*;
import java.util.*;

@SuppressWarnings("serial")
public class KNN_lsh extends Classifier {

  private Instances m_Train;
  private int m_kNN;
  
  public void setM_kNN(int m_kNN) {
    this.m_kNN = m_kNN;
  }
  
  public void buildClassifier(Instances data) throws Exception {
    m_Train = new Instances(data);  
    
  }

  public double[] distributionForInstance(Instance instance) throws Exception {

    Instances instances= findNeighbors(instance, m_kNN);
    return computeDistribution(instances, instance);
  }
  
  private Instances findNeighbors(Instance instance, int kNN) {
    double distance;  
    List<HasDisInstances> neighborlist = new LinkedList<>();
    
    for (int i = 0; i < m_Train.numInstances(); i++) {
      Instance trainInstance = m_Train.instance(i);
      distance = distance(instance, trainInstance);
      HasDisInstances hasDisInstances=new HasDisInstances(distance,trainInstance);
      
      if(i==0 || (i<kNN-1 && neighborlist.get(neighborlist.size()-1).distance<distance))
        neighborlist.add(hasDisInstances);
      else{
        for (int j = 0; j < kNN && j<neighborlist.size(); j++) {
          if(distance<neighborlist.get(j).distance){
            neighborlist.add(j, hasDisInstances);
            break;
          }
        }
      }
    }
    
    int min=Math.min(kNN, neighborlist.size());
    Instances instances=new Instances(m_Train,min);
    for(int i=0;i<min;i++){
      instances.add(neighborlist.get(i).instance);
    }
    return instances;
  }

  private double distance(Instance first, Instance second) {

    double distance = 0;
    for (int i = 0; i < m_Train.numAttributes(); i++) {
      if (i == m_Train.classIndex())
        continue;
      if((int)first.value(i)!=(int)second.value(i)){
            distance+=1;
          }
//      //此處修改距離計算公式
//      distance+=(second.value(i)-first.value(i))*(second.value(i)-first.value(i));//歐基米德爾公式
//      distance+=second.value(i)*Math.log(second.value(i)/first.value(i));最大熵
//      distance+=Math.pow((second.value(i)-first.value(i)), 2)/first.value(i);//卡方距離
    }
//    distance=Math.sqrt(distance);
    return distance;
  }

  private double[] computeDistribution(Instances data, Instance instance) throws Exception {
    
      double[] prob=new double[data.numClasses()];

      for (int i=0;i<data.numInstances();i++){
        int classVal=(int)data.instance(i).classValue();
        double x=distance(instance, data.instance(i));
        prob[classVal] +=1+Math.exp(-x*x/0.18);//c=0.3
      }
    Utils.normalize(prob);
    return prob;
  }

  private class HasDisInstances{
    double distance;
    Instance instance;
    public HasDisInstances(double distance, Instance instance) {
      this.distance = distance;
      this.instance = instance;
    }
  }
}      
package cug.lsh;

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

/**
 * Demo driver for {@link KNN_lsh}. It holds out part of a dataset as a test set, tries
 * k = 3, 4, ... and prints the best test accuracy together with the k that achieved it, as
 * {@code accuracy,k}.
 *
 * <p>NOTE: k is selected by accuracy on the same test set whose accuracy is printed, so the
 * reported figure is optimistically biased. Use a separate validation split or
 * cross-validation for an unbiased estimate.
 */
public class KNN_lsh_use {

  /** Dataset loaded when no path is given on the command line. */
  private static final String DEFAULT_DATA_PATH = "E:/DataLearing/data/credit-g.arff";

  /** Fraction of the instances held out as the test set. */
  private static final double TEST_FRACTION = 0.2;

  /** Largest k that is tried. */
  private static final int MAX_K = 20;

  /**
   * Runs the k search.
   *
   * @param args optional; {@code args[0]} is the path of the dataset to load
   * @throws Exception if the dataset cannot be read or classification fails
   */
  public static void main(String[] args) throws Exception {
    String path = args.length > 0 ? args[0] : DEFAULT_DATA_PATH;
    Instances data = DataSource.read(path);
    data.setClassIndex(data.numAttributes() - 1);

    // The first TEST_FRACTION of the instances form the test set; the rest (disjoint) train.
    int size = (int) (data.numInstances() * TEST_FRACTION);
    if (size == 0) {
      throw new IllegalArgumentException("dataset too small to hold out a test set: " + path);
    }
    Instances test = new Instances(data, 0, size);
    Instances train = new Instances(data, size, data.numInstances() - size);

    KNN_lsh classifier = new KNN_lsh();
    int optiK = 0;
    int bestCorrect = 0; // number of correctly classified test instances for optiK
    for (int k = 3; k < Math.sqrt(train.numInstances()) + 3 && k <= MAX_K; k++) {
      classifier.setM_kNN(k);
      classifier.buildClassifier(train);

      int correct = 0;
      for (int i = 0; i < test.numInstances(); i++) {
        if (classifier.classifyInstance(test.instance(i)) == test.instance(i).classValue()) {
          correct++;
        }
      }
      if (correct > bestCorrect) {
        optiK = k;
        bestCorrect = correct;
      }
    }

    System.out.println(1.0 * bestCorrect / test.numInstances() + "," + optiK);
  }
}