1 源碼下載
下載代碼
2 代碼截圖
![](https://img.laitimes.com/img/9ZDMuAjOiMmIsIjOiQnIsIyZwpmLwIDO1ETM0IjM3IjMwkTMwIzLc52YucWbp5GZzNmLn9Gbi1yZtl2Lc9CX6MHc0RHaiojIsJye.jpg)
3 KNN代碼實作
import random
import math
from operator import itemgetter
"""
定義加載資料函數
fileName: 檔案名稱
split:分割點,将原始資料分為訓練集和測試集
trainingData:訓練集
testData:測試集
"""
def loadData(fileName, split, trainingData, testData):
    """Read a comma-separated data file and randomly split its rows.

    Every non-blank line is split on commas into a list of strings. The row
    is appended to ``trainingData`` when ``random.random() < split`` and to
    ``testData`` otherwise. Both lists are mutated in place; nothing is
    returned.

    Args:
        fileName: Path of the text file to read.
        split: Threshold compared against ``random.random()``; larger values
            send more rows to the training list.
        trainingData: List that receives the training rows.
        testData: List that receives the test rows.
    """
    # The context manager closes the handle even if reading fails part-way.
    with open(fileName, 'r') as file:
        for line in file:
            line = line.rstrip('\n')
            # A blank line would turn into a [''] row, which has no feature
            # columns and breaks the distance computation later on.
            if not line.strip():
                continue
            row = line.strip(',').split(',')
            if random.random() < split:
                trainingData.append(row)
            else:
                testData.append(row)
"""
得到最近的點
trainingDataList:訓練集
testData:測試集
"""
def getNeighbors(trainingDataList, testData, K):
    """Return the K training rows closest to ``testData``.

    Distances are computed with ``euclideanDistance`` over every column
    except the last one of ``testData``. Rows are ordered by ascending
    distance; rows at equal distance keep their order from
    ``trainingDataList`` because the sort is stable.

    Args:
        trainingDataList: Sequence of training rows.
        testData: The row to find neighbours for.
        K: Maximum number of neighbours to return.

    Returns:
        A list of at most ``K`` rows from ``trainingDataList``. Fewer rows are
        returned when the training set holds fewer than ``K`` rows.
    """
    # The last column is skipped: the rest of the file reads it as the label.
    dimension = len(testData) - 1
    distances = [
        (trainingData, euclideanDistance(trainingData, testData, dimension))
        for trainingData in trainingDataList
    ]
    distances.sort(key=itemgetter(1))
    # Slicing (rather than indexing range(K)) avoids an IndexError when K is
    # larger than the training set; max() keeps a negative K yielding [].
    return [row for row, _ in distances[:max(K, 0)]]
"""
進行投票,少數服從多數
例如 A:3票, B:2票,C:5票
那麼結果為 C
"""
def vote(neighbors):
    """Pick the label that occurs most often among the neighbours.

    The label of a neighbour is its last element. For example, with votes
    A: 3, B: 2, C: 5 the result is C. When several labels share the highest
    count, the one encountered first in ``neighbors`` wins, because the
    descending sort is stable.

    Args:
        neighbors: Non-empty sequence of rows whose last element is a label.

    Returns:
        The label with the highest vote count.

    Raises:
        IndexError: If ``neighbors`` is empty.
    """
    tally = {}
    for row in neighbors:
        label = row[-1]
        tally[label] = tally.get(label, 0) + 1
    ranking = sorted(tally.items(), key=itemgetter(1), reverse=True)
    winner, _count = ranking[0]
    return winner
"""
估算兩點之間的歐幾裡得距離
point1 : 第一個點
point2 : 第二個點
dimension : 次元,比如 x1(7.5) 次元為1, x2(1,5)次元為2, x3(1,6,32)次元為3
"""
def euclideanDistance(point1, point2, dimension):
    """Compute the Euclidean distance between two points.

    Only the first ``dimension`` coordinates of each point are used; every
    coordinate is converted with ``float()`` first, so numeric strings are
    accepted.

    Args:
        point1: First point, an indexable sequence of coordinates.
        point2: Second point, an indexable sequence of coordinates.
        dimension: Number of leading coordinates to compare.

    Returns:
        The square root of the summed squared coordinate differences.
    """
    squared_sum = 0
    for axis in range(dimension):
        delta = float(point1[axis]) - float(point2[axis])
        squared_sum += delta ** 2
    return math.sqrt(squared_sum)
"""
得到預測的精确度
"""
def getAccuracy(forecastList):
    """Return the fraction of forecasts that match the true label.

    Each item is indexed as ``item[0]`` (the data row, whose last element is
    the true label) and ``item[1]`` (the predicted label).

    Args:
        forecastList: Non-empty sequence of ``(row, predicted_label)`` items.

    Returns:
        ``1 - errors / total`` as a float.

    Raises:
        ZeroDivisionError: If ``forecastList`` is empty.
    """
    misses = sum(1 for item in forecastList if item[1] != item[0][-1])
    return 1 - misses / len(forecastList)
"""
main函數
"""
def main(fileName="D:/workspace/MachineLearning/07stage/1-fundamental/01/KNN/irisdata.txt",
         split=0.9, K=5):
    """Run one KNN experiment and print its accuracy and error rate.

    The data file is loaded and randomly split by ``loadData``; each test row
    is then classified by a majority vote among its ``K`` nearest training
    rows, and the resulting accuracy is printed as a percentage.

    Args:
        fileName: Path of the comma-separated data file. The default is the
            location that used to be hard-coded here.
        split: Threshold passed to ``loadData`` for the train/test split.
        K: Number of neighbours that take part in each vote.
    """
    trainingDataList = []
    testDataList = []
    loadData(fileName, split, trainingDataList, testDataList)
    forecastList = [
        (testData, vote(getNeighbors(trainingDataList, testData, K)))
        for testData in testDataList
    ]
    accuracy = getAccuracy(forecastList)
    print("準确度:{}% ,錯誤率 : {} %".format(accuracy * 100, (1 - accuracy) * 100))
# Only run the experiments when executed as a script, so that importing this
# module does not trigger file I/O and 100 training runs as a side effect.
if __name__ == "__main__":
    # Repeat the experiment 100 times; main() prints one result line per run.
    for _ in range(100):
        main()