天天看點

常見機器學習算法學習——KNN(K鄰近)

1、 算法簡述    

文章中描述性内容,多來自維基百科KNN。

KNN( k-nearest neighbors algorithm)是一種非參數、有監督算法,由T. M. COVER, P. E. HART, Hart PE 在1967年提出,後被廣泛應用于模式識别領域,既可用于分類也可以用于回歸。

KNN是一種懶學習(lazy learning)算法,沒有顯式的訓練過程;将多元特征空間中已經帶有分類(分類問題)或屬性值(回歸問題)的資料集看做訓練集,計算預測樣本與訓練集中個樣本的鄰近度(實用向量距離展現,如歐氏距離,馬氏距離等),取K個最鄰近樣本,根據鄰近樣本投票(分類問題)或權重(回歸問題),确定預測樣本的分類或屬性。是以,KNN算法的計算量非常大。

常見機器學習算法學習——KNN(K鄰近)

标題

借用維基百科這張被借用無數次的圖檔來大緻描述KNN分類。圖中藍色及紅色色塊表示訓練集中兩類兩類已打标簽的樣本,綠色圓形色塊為待确定分類的預測集,采用歐氏距離(及實際距離)為鄰近度準則,當K=3時,有效色塊即為圖中黑色圓形實線包圍的色塊,範圍内紅色三角形類别色塊2個,方形藍色色塊1個,則訓練集中的兩類色塊紅色占多,則預測集綠色圓形色塊被歸入紅色三角一類;當K=5時,有效色塊即為圖中黑色圓形虛線包圍的色塊,範圍内紅色三角形類别色塊2個,方形藍色色塊3個,則訓練集中的兩類色塊藍色占多,則預測集綠色圓形色塊被歸入藍色方形一類;

2、算法特點

3、相似算法

4、算法實作

KNN最友善的實作方法當然是調用sklearn中的KNN相應方法,為了更清晰展示KNN的基本原理,此處給出僅實作KNN基本過程的玩具代碼:

# -*- coding: utf-8 -*-
"""
Created on Tue Sep 25 01:18:35 2018

@author: yzp1011
"""
import numpy as np
import pandas as pd
from functools import partial 


class KNN(object):
    def eu_distance(self,p1,p2):
        return np.linalg.norm(p1 - p2)

    def get_neighbors(self,train_set,test_set,k):
        vec = pd.DataFrame([partial(self.eu_distance,x) for x in train_set[:,0:-1]],columns = ['train'])
        
        print('vec\'s shae:{},label shape:{}'.format(vec.shape,len(train_set)))
        vec['label'] = train_set[:,-1].T
        try:
            assert test_set.shape[1] > 0
            for v in test_set:
                res = vec['train'].map(lambda x:x(v))
                out = vec.iloc[res.nlargest(k).index]['label'].mode()[0]
                print('x:{}-mode:{}'.format(v,out)) 
#        當預測向量長度為1時
        except IndexError:
            res = vec['train'].map(lambda x:x(test_set))
            out = vec.iloc[res.nlargest(k).index]['label'].mode()[0]
            print('x:{}-mode:{}'.format(test_set,out))           
        return self
    
    
if __name__ == '__main__':
    model = KNN()
    train_set = [[1,1,1,1], [2,2,2,1], [1,1,3,1], [4,4,4,2], [0,0,0,1], [4,4.5,4,2]]
    test_instance = [5,5,5]
    k = 3
    neigbors = model.get_neighbors(np.array(train_set),np.array(test_instance),k)