天天看點

2.2 KNN算法實作

1 源碼下載

下載代碼

2 代碼截圖

2.2 KNN算法實作

3 KNN代碼實作

import random
import math
from operator import itemgetter

def loadData(fileName, split, trainingData, testData):
    """Read a comma-separated data file and randomly split its rows.

    Every non-blank line is split on commas into a list of strings and
    appended in place to either ``trainingData`` or ``testData``.

    Args:
        fileName: Path of the text file to read.
        split: Threshold compared against ``random.random()`` for each row;
            the row goes to ``trainingData`` when the draw is below it.
        trainingData: List that receives the training rows (mutated).
        testData: List that receives the remaining rows (mutated).
    """
    # The context manager closes the file even if parsing raises.
    with open(fileName, 'r') as handle:
        for line in handle:
            record = line.strip('\n').strip(',')
            # A blank line would otherwise turn into a [''] row, which cannot
            # be converted to numeric features later on.
            if not record.strip():
                continue
            row = record.split(',')
            if random.random() < split:
                trainingData.append(row)
            else:
                testData.append(row)

def getNeighbors(trainingDataList, testData, K):
    """Return the training rows closest to a test row.

    Distances are computed with ``euclideanDistance`` over the first
    ``len(testData) - 1`` columns; the final column is left out of the
    distance.

    Args:
        trainingDataList: Sequence of training rows.
        testData: The row whose neighbors are wanted.
        K: Maximum number of neighbors to return.

    Returns:
        A list of at most ``K`` training rows ordered from nearest to
        farthest. Fewer than ``K`` rows are returned when the training set is
        smaller than ``K``; an empty list is returned when ``K <= 0``.
    """
    dimension = len(testData) - 1
    distances = [
        (trainingData, euclideanDistance(trainingData, testData, dimension))
        for trainingData in trainingDataList
    ]
    distances.sort(key=itemgetter(1))
    # Slicing tolerates K larger than the training set, where indexing each of
    # the first K positions would raise IndexError. max() keeps a non-positive
    # K returning an empty list instead of slicing from the end.
    return [row for row, _ in distances[:max(K, 0)]]

def vote(neighbors):
    """Pick the label that occurs most often among the neighbors.

    The label of a neighbor is its last element. Majority rules: with
    A: 3 votes, B: 2 votes and C: 5 votes the result is C. When several
    labels share the highest count, the one encountered first wins because
    the sort is stable.

    Args:
        neighbors: Sequence of rows whose last element is the label.

    Returns:
        The label with the highest vote count.

    Raises:
        IndexError: If ``neighbors`` is empty.
    """
    tally = {}
    for row in neighbors:
        label = row[-1]
        tally[label] = tally.get(label, 0) + 1
    ranking = list(tally.items())
    ranking.sort(key=lambda entry: entry[1], reverse=True)
    return ranking[0][0]

def euclideanDistance(point1, point2, dimension):
    """Compute the Euclidean distance between two points.

    Only the first ``dimension`` coordinates of each point are used, and each
    coordinate is converted with ``float`` before subtracting, so numeric
    strings are accepted.

    Args:
        point1: First point (indexable sequence of numbers or numeric strings).
        point2: Second point, same layout as ``point1``.
        dimension: Number of leading coordinates to compare, e.g. 1 for
            (7.5), 2 for (1, 5), 3 for (1, 6, 32).

    Returns:
        The distance as a float (0.0 when ``dimension`` is 0).
    """
    squared_sum = 0
    for axis in range(dimension):
        delta = float(point1[axis]) - float(point2[axis])
        squared_sum += delta ** 2
    return math.sqrt(squared_sum)

def getAccuracy(forecastList):
    """Compute the fraction of correct predictions.

    Args:
        forecastList: Sequence of ``(row, predicted_label)`` pairs; the true
            label is the last element of ``row``.

    Returns:
        ``1 - errors / total`` as a float between 0 and 1.

    Raises:
        ZeroDivisionError: If ``forecastList`` is empty.
    """
    misses = sum(
        1 for forecast in forecastList if forecast[1] != forecast[0][-1]
    )
    error_rate = misses / len(forecastList)
    return 1 - error_rate


def main(fileName="D:/workspace/MachineLearning/07stage/1-fundamental/01/KNN/irisdata.txt",
         split=0.9, K=5):
    """Run one KNN experiment and print its accuracy and error rate.

    The data file is loaded and randomly split into training and test rows,
    every test row is classified by a majority vote among its K nearest
    training rows, and the resulting accuracy is printed as a percentage.

    Args:
        fileName: Path of the comma-separated data file. Defaults to the
            path originally hard-coded in this script.
        split: Threshold passed to ``loadData`` for the random train/test
            split.
        K: Number of neighbors that vote on each test row.
    """
    trainingDataList = []
    testDataList = []
    loadData(fileName, split, trainingDataList, testDataList)
    forecastList = [
        (testData, vote(getNeighbors(trainingDataList, testData, K)))
        for testData in testDataList
    ]
    accuracy = getAccuracy(forecastList)
    print("準确度:{}% ,錯誤率 : {} %".format(accuracy * 100, (1 - accuracy) * 100))
# Repeat the experiment 100 times; each run draws a fresh random split.
# The guard keeps the runs from starting when this module is merely imported.
if __name__ == "__main__":
    for _ in range(100):
        main()

           

繼續閱讀