天天看點

KNN的python實作

KNN的原理是:取距離待預測樣本最近的k個點,然後將這k個點中出現次數最多的類別作為預測類別。原理簡單,代碼也很簡單,如下:

# coding=utf8
import sys
import os
import numpy as np
from numpy import *

# Re-import sys and force the interpreter-wide default string encoding to UTF-8.
# NOTE(review): this reload()/setdefaultencoding() pair is a Python 2 idiom;
# reload is not a builtin on Python 3 — confirm the target interpreter.
reload(sys)
sys.setdefaultencoding('utf-8')
# Switch the working directory to a hard-coded local path. Nothing visible in
# this section opens files relative to it.
# NOTE(review): os.chdir raises if this directory does not exist on the machine.
os.chdir(r'D:\Study\ML\MLAction')

def euclidean_dist(x1, x2):
    """Return the Euclidean (L2) distance between two points.

    Args:
        x1: Coordinates of the first point (sequence, ndarray or matrix).
        x2: Coordinates of the second point, broadcast-compatible with x1.

    Returns:
        The distance sqrt(sum((x1 - x2) ** 2)) as a Python float.
    """
    # np.asarray makes the subtraction element-wise for np.matrix input too;
    # the earlier sqrt(x1.T * x2) took the root of a product, not a distance.
    diff = np.asarray(x1, dtype=float) - np.asarray(x2, dtype=float)
    return float(np.sqrt(np.sum(np.square(diff))))

def kNNClassify(X, label, predict_X, k):
    """Classify one sample by majority vote among its k nearest neighbours.

    Args:
        X: Training samples, one row per sample (2-D ndarray, matrix or
            nested list of numbers).
        label: Sequence of class labels; label[i] belongs to row i of X.
        predict_X: Feature vector of the sample to classify. It must have as
            many components as X has columns.
        k: Number of nearest neighbours that take part in the vote.

    Returns:
        The label that occurs most often among the k nearest training
        samples. On a tie, the tied label that was met first while walking
        the neighbours from nearest to farthest wins (this relies on dicts
        preserving insertion order, i.e. Python 3.7+).

    Raises:
        ValueError: If k is not between 1 and the number of training samples.
    """
    # Plain ndarrays keep np.sum(..., axis=1) one-dimensional even when the
    # caller passes an np.matrix, so no reshape/tolist round trip is needed.
    samples = np.asarray(X, dtype=float)
    query = np.asarray(predict_X, dtype=float)

    # Squared distances rank the neighbours exactly like true distances do,
    # so the square root is skipped.
    dist = np.sum(np.power(samples - query, 2), axis=1)

    n_samples = len(dist)
    if not 1 <= k <= n_samples:
        raise ValueError(
            "k must be between 1 and %d, got %r" % (n_samples, k)
        )

    # argsort()[i] is the row index of the i-th nearest sample; a stable sort
    # keeps equally distant samples in their original row order.
    nearest = np.argsort(dist, kind="stable")[:k]

    votes = {}
    for row in nearest:
        lbl = label[int(row)]
        votes[lbl] = votes.get(lbl, 0) + 1
    return max(votes, key=votes.get)

def createDataSet():
    """Build the four-point toy training set used by the demo.

    Returns:
        A (group, labels) tuple: group is a 4x2 np.matrix with one sample per
        row, and labels is the list of class names aligned with those rows.
    """
    # np.asmatrix is spelled out instead of the wildcard-imported `mat`
    # alias, which newer NumPy releases no longer export.
    group = np.asmatrix([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels

# Demo: classify one query point against the toy data set with k=3.
group,labels=createDataSet()
# The query needs one value per feature column (two here). The previous
# three-element list [0,0,1] could not be subtracted from the 2-column
# training rows; it is read here as a typo for [0,0.1].
result=kNNClassify(group,labels,[0,0.1],3)
print(result)