本文在weka下,主要使用高斯函數權重,選取最優K值進行優化。你也可以參考網上文檔,將下文的
KNN_lsh.java
複製到某一目錄並進行相關設定,進而在weka gui中測試改進效果。
檔案目錄:

package cug.lsh;
import weka.classifiers.*;
import weka.core.*;
import java.util.*;
@SuppressWarnings("serial")
public class KNN_lsh extends Classifier {
    /** Copy of the training data, stored verbatim by {@link #buildClassifier}. */
    private Instances m_Train;
    /** Number of neighbours to consult; must be set via {@link #setM_kNN} before use. */
    private int m_kNN;

    public void setM_kNN(int m_kNN) {
        this.m_kNN = m_kNN;
    }

    /**
     * Lazy learner: training just keeps a defensive copy of the data.
     *
     * @param data training instances (class index must already be set)
     */
    public void buildClassifier(Instances data) throws Exception {
        m_Train = new Instances(data);
    }

    /**
     * Class distribution for one test instance, computed from its m_kNN
     * nearest training instances with a Gaussian-style weight.
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        Instances neighbors = findNeighbors(instance, m_kNN);
        return computeDistribution(neighbors, instance);
    }

    /**
     * Returns the (up to) kNN training instances closest to {@code instance}.
     *
     * Fix: the original hand-rolled sorted insertion silently dropped any
     * candidate whose distance exceeded the current list maximum once
     * i &gt;= kNN-1, even when fewer than kNN neighbours had been collected,
     * so fewer neighbours than requested could be returned. Collecting all
     * candidates and sorting by distance is simple and always correct.
     */
    private Instances findNeighbors(Instance instance, int kNN) {
        List<HasDisInstances> candidates = new ArrayList<>(m_Train.numInstances());
        for (int i = 0; i < m_Train.numInstances(); i++) {
            Instance trainInstance = m_Train.instance(i);
            candidates.add(new HasDisInstances(distance(instance, trainInstance), trainInstance));
        }
        // Sort ascending by distance; ties keep training order (stable sort).
        Collections.sort(candidates, new Comparator<HasDisInstances>() {
            public int compare(HasDisInstances a, HasDisInstances b) {
                return Double.compare(a.distance, b.distance);
            }
        });
        int min = Math.min(kNN, candidates.size());
        Instances instances = new Instances(m_Train, min);
        for (int i = 0; i < min; i++) {
            instances.add(candidates.get(i).instance);
        }
        return instances;
    }

    /**
     * Overlap (Hamming-style) distance over all non-class attributes:
     * +1 for every attribute whose (nominal index) values differ.
     * NOTE(review): the (int) casts assume nominal attributes — confirm
     * the data contains no numeric attributes, or the cast truncates.
     */
    private double distance(Instance first, Instance second) {
        double distance = 0;
        for (int i = 0; i < m_Train.numAttributes(); i++) {
            if (i == m_Train.classIndex())
                continue;
            if ((int) first.value(i) != (int) second.value(i)) {
                distance += 1;
            }
            // Alternative distance formulas (swap in here if desired):
            // distance += (second.value(i)-first.value(i))*(second.value(i)-first.value(i)); // squared Euclidean
            // distance += second.value(i)*Math.log(second.value(i)/first.value(i)); // relative entropy
            // distance += Math.pow((second.value(i)-first.value(i)), 2)/first.value(i); // chi-square distance
        }
        // distance = Math.sqrt(distance);
        return distance;
    }

    /**
     * Votes each neighbour into its class with weight 1 + exp(-d^2 / (2*0.3^2)),
     * i.e. a Gaussian kernel (c = 0.3) plus a constant 1 so every neighbour
     * always contributes, then normalizes to a probability distribution.
     *
     * @throws Exception if {@code data} is empty (Utils.normalize rejects a zero sum)
     */
    private double[] computeDistribution(Instances data, Instance instance) throws Exception {
        double[] prob = new double[data.numClasses()];
        for (int i = 0; i < data.numInstances(); i++) {
            int classVal = (int) data.instance(i).classValue();
            double x = distance(instance, data.instance(i));
            prob[classVal] += 1 + Math.exp(-x * x / 0.18); // Gaussian weight, c = 0.3
        }
        Utils.normalize(prob);
        return prob;
    }

    /** Pairs a training instance with its distance to the query instance. */
    private class HasDisInstances {
        double distance;
        Instance instance;

        public HasDisInstances(double distance, Instance instance) {
            this.distance = distance;
            this.instance = instance;
        }
    }
}
package cug.lsh;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
public class KNN_lsh_use {
    /**
     * Loads the credit-g data, holds out the first 20% as a test set, then
     * searches k in [3, min(sqrt(n)+3, 20)] for the best test accuracy and
     * prints "accuracy,bestK".
     *
     * Fix: the original split did {@code train.delete(i)} while iterating
     * forward and reading {@code train.instance(i)} — each deletion shifts
     * the remaining indices, so the test set received instances 0, 2, 4, …
     * and the wrong rows were removed from training (train and test ended up
     * overlapping). Always taking and deleting instance 0 moves the first
     * {@code size} instances cleanly from train to test.
     */
    public static void main(String[] args) throws Exception {
        Instances train = DataSource.read("E:/DataLearing/data/credit-g.arff");
        train.setClassIndex(train.numAttributes() - 1);
        int size = (int) (train.numInstances() * 0.2); // hold out first 20% as test set
        Instances test = new Instances(train, size);
        test.setClassIndex(test.numAttributes() - 1);
        for (int i = 0; i < size; i++) {
            // Move (not copy) the head instance from train into test.
            test.add(train.instance(0));
            train.delete(0);
        }
        KNN_lsh classifier = new KNN_lsh();
        // Search for the best k value.
        int optiK = 0;
        int bestCount = 0; // highest number of correct predictions seen so far
        for (int m_kNN = 3; m_kNN < Math.sqrt(train.numInstances()) + 3 && m_kNN <= 20; m_kNN++) {
            classifier.setM_kNN(m_kNN);
            classifier.buildClassifier(train);
            int count = 0;
            for (int i = 0; i < test.numInstances(); i++) {
                if (classifier.classifyInstance(test.instance(i)) == test.instance(i).classValue())
                    count++;
            }
            if (count > bestCount) {
                optiK = m_kNN;
                bestCount = count;
            }
        }
        System.out.println(1.0 * bestCount / test.numInstances() + "," + optiK);
    }
}