
A Python implementation of the KNN algorithm with iterative training: the most recently classified test sample is added to the training set

The principles behind KNN are not repeated here; this post goes straight to how to use the source code.

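Save the code as a .py file. By default it reads two files from the data directory next to the code file: knn_train.txt as the training samples and knn_test.txt as the test samples, with k_value=3. These settings can be changed in the source code or passed as command-line arguments, as in the following invocation: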

python knn.py train.txt test.txt 4

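The three arguments after the script name are, in order, the training set, the test set, and the k_value. The k_value may be omitted, in which case the default is used.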
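The post does not spell out the layout of the two data files. Judging from how the listing indexes each row, a training row appears to hold the class label first, followed by the feature values, while a test row holds the feature values only, all separated by whitespace. The sketch below illustrates that reading; the function name parse_rows and the sample contents are hypothetical and invented purely for illustration.

```python
def parse_rows(text):
    """Split whitespace-separated text into rows of string fields, skipping blank lines."""
    return [line.split() for line in text.splitlines() if line.strip()]


# Hypothetical file contents, invented for illustration only.
train_text = "A 1.0 1.1\nA 1.2 0.9\nB 5.0 5.2\n"   # label, then features
test_text = "1.1 1.0\n5.1 4.9\n"                    # features only

train_rows = parse_rows(train_text)
test_rows = parse_rows(test_text)

# Column 0 of a training row is the label; the rest are numeric features.
labels = [row[0] for row in train_rows]
features = [[float(v) for v in row[1:]] for row in train_rows]
```

Under this layout, a test row is always one field shorter than a training row, which is why the listing compares test column n with training column n+1.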

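By default, each test sample is added to the training set as soon as it has been classified. To turn this off, delete the two statements at the end of the source code, marked "apply the newest test data as train data".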
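The effect of feeding classified samples back into the training set can be seen in a small sketch that is separate from the listing in this post. The names classify, classify_all and the grow flag are hypothetical, the data is invented, and the compact KNN here uses sorting rather than the listing's approach.

```python
from collections import Counter


def classify(train, point, k):
    """Return the majority label among the k training rows nearest to point.

    train is a list of (label, features) pairs; distance is squared Euclidean.
    """
    nearest = sorted(
        train,
        key=lambda row: sum((a - b) ** 2 for a, b in zip(row[1], point)),
    )[:k]
    votes = Counter(label for label, _ in nearest)
    return votes.most_common(1)[0][0]


def classify_all(train, test, k, grow=True):
    """Classify test points in order; with grow=True each result joins the training set."""
    train = list(train)  # work on a copy so the caller's list is untouched
    results = []
    for point in test:
        label = classify(train, point, k)
        results.append(label)
        if grow:
            train.append((label, point))
    return results


# Invented one-dimensional data: one sample of each class.
train = [("A", [0.0]), ("B", [10.0])]
test = [[4.0], [6.5]]

with_growth = classify_all(train, test, k=1, grow=True)
without_growth = classify_all(train, test, k=1, grow=False)
```

In this invented case the second point is labelled differently depending on whether the first result was added to the training set, which shows that the order of the test samples matters and that an early misclassification can influence later ones.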

Special thanks to Android路上的人 for the test data.

The Python source code is as follows:

__author__ = 'Administrator'
############      KNN           ###############
#####        created: 2016-03-16 11:51:03          #####

import re
import sys


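k_value = 3        # the parameter k: number of nearest neighbours that vote
DataLength = 100   # initial capacity of the row buffers below
tr_data = [0 for i in range(DataLength)]     # training rows: label first, then features
test_data = [0  for i in range(DataLength)]  # test rows: features only
tr_lg = test_lg = wd = 0   # training row count, test row count, features per sample
types = set()              # class labels seen in the training data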

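def knn(k, i, lg):
    """Classify test sample i by majority vote among its k nearest training samples.

    lg is the number of distinct class labels.
    """
    # Squared Euclidean distance from test sample i to every training sample.
    # Column 0 of a training row is the label, so its features start at index 1.
    ls = [0 for x in range(tr_lg)]
    for m in range(tr_lg):
        s = 0
        for n in range(wd):
            d = float(test_data[i][n]) - float(tr_data[m][n+1])
            s += d * d
        ls[m] = s
    ll = [0 for j in range(lg)]  # vote count per label, in the order of tp
    tp = list(types)
    for j in range(k):
        m = ls.index(min(ls))        # nearest sample not yet counted
        n = tp.index(tr_data[m][0])
        ll[n] += 1
        # Mark this neighbour as used so that it cannot be picked again.
        ls[m] = float("inf")
    return tp[ll.index(max(ll))]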



if __name__ == '__main__':
    train = "data/knn_train.txt"
    test = "data/knn_test.txt"
    # Optional arguments: <training file> <test file> [k_value]
    if len(sys.argv) > 2:
        train = sys.argv[1]
        test = sys.argv[2]
    if len(sys.argv) > 3:
        k_value = int(sys.argv[3])

    # Training rows: label followed by the feature values, whitespace-separated.
    tr_data = []
    with open(train, "r") as fp1:
        for line in fp1:
            ls = line.split()
            if ls:  # skip blank lines
                tr_data.append(ls)
                types.add(ls[0])
    tr_lg = len(tr_data)

    # Test rows: feature values only.
    test_data = []
    with open(test, "r") as fp2:
        for line in fp2:
            ls = line.split()
            if ls:
                test_data.append(ls)
    test_lg = len(test_data)
    wd = len(test_data[0])

    # k cannot exceed the number of training samples.
    if k_value > tr_lg:
        k_value = tr_lg
    for i in range(test_lg):
        s = knn(k_value, i, len(types))
        print("%s\ttype: %s" % (test_data[i], s))

        # apply the newest test data as train data
        tr_data.append([s] + test_data[i])
        tr_lg += 1
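The listing finds each neighbour by scanning the whole distance list once per vote. A possible alternative, which is not part of the original code, is to let the standard library pick the k smallest distances with heapq.nsmallest and count the votes with collections.Counter. The function name knn_vote and the sample rows below are hypothetical; the rows only mimic the shape the listing uses (label first, then feature strings).

```python
import heapq
from collections import Counter


def knn_vote(train_rows, sample, k):
    """Predict a label for sample from the k nearest rows of train_rows.

    train_rows: rows with the label first, then the feature values.
    sample: feature values only.
    """
    def sq_dist(row):
        # Squared Euclidean distance; row[1:] skips the label column.
        return sum((float(a) - float(b)) ** 2 for a, b in zip(sample, row[1:]))

    nearest = heapq.nsmallest(k, train_rows, key=sq_dist)
    votes = Counter(row[0] for row in nearest)
    return votes.most_common(1)[0][0]


# Invented rows for illustration only.
rows = [["A", "1.0", "1.1"], ["A", "1.2", "0.9"], ["B", "5.0", "5.2"]]
prediction = knn_vote(rows, ["1.1", "1.0"], 3)
```

Note that when two labels receive the same number of votes, this sketch may break the tie differently from the listing, which depends on the order of the labels in its types set.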