
CS231n kNN Python assignment: kNN program and predict_labels

kNN program

http://cs231n.github.io/classification/

L1 distance

$$d_1(I_1, I_2) = \sum_p \left| I_1^p - I_2^p \right|$$
# Method of the NearestNeighbor class from the CS231n notes; assumes `import numpy as np`
def predict(self, X):
    """ X is N x D where each row is an example we wish to predict label for """
    num_test = X.shape[0]
    # make sure that the output type matches the input type
    Ypred = np.zeros(num_test, dtype=self.ytr.dtype)

    # loop over all test rows (num_test is 10000, one row per test image)
    for i in range(num_test):
        # find the nearest training image to the i'th test image
        # using the L1 distance (sum of absolute value differences);
        # X[i,:] holds the 3072 pixel values of the i'th test image
        # self.Xtr is 50000 x 3072 and X[i,:] has 3072 values, so broadcasting
        # subtracts X[i,:] from every training row and distances holds 50000 values
        distances = np.sum(np.abs(self.Xtr - X[i,:]), axis=1)
        min_index = np.argmin(distances) # get the index with smallest distance
        Ypred[i] = self.ytr[min_index] # predict the label of the nearest example

    return Ypred
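
To try the predict method on something small, here is a minimal sketch. It assumes a NearestNeighbor class whose train method only stores the data as self.Xtr and self.ytr, as in the CS231n notes; the toy arrays are made up for illustration and are not CIFAR-10 data.

```python
import numpy as np

class NearestNeighbor:
    def train(self, X, y):
        # the nearest neighbor classifier simply remembers all the training data
        self.Xtr = X
        self.ytr = y

    def predict(self, X):
        num_test = X.shape[0]
        Ypred = np.zeros(num_test, dtype=self.ytr.dtype)
        for i in range(num_test):
            # L1 distance from test row i to every training row
            distances = np.sum(np.abs(self.Xtr - X[i, :]), axis=1)
            Ypred[i] = self.ytr[np.argmin(distances)]
        return Ypred

# made-up toy data: 3 training rows with 4 features each
Xtr = np.array([[0, 0, 0, 0], [10, 10, 10, 10], [5, 5, 5, 5]])
ytr = np.array([0, 1, 2])
Xte = np.array([[1, 0, 1, 0], [9, 9, 10, 10], [4, 6, 5, 5]])

nn = NearestNeighbor()
nn.train(Xtr, ytr)
pred = nn.predict(Xte)
print(pred)  # each test row is closest to a different training row: [0 1 2]
```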
           

2018.5.11

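def compute_distances_no_loops(self, X):
    num_test = X.shape[0]
    num_train = self.X_train.shape[0]
    dists = np.zeros((num_test, num_train))
    # squared norm of every test row, shape (num_test,)
    test_sum = np.sum(np.square(X), axis=1)
    # squared norm of every training row, shape (num_train,)
    train_sum = np.sum(np.square(self.X_train), axis=1)
    # dot product of every test/train pair, shape (num_test, num_train)
    inner_product = np.dot(X, self.X_train.T)
    # the (num_test, 1) column broadcasts against the (num_train,) row
    dists = np.sqrt(-2 * inner_product + test_sum.reshape(-1, 1) + train_sum)
    return dists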
           

For computing the distances without loops, see:

https://blog.csdn.net/zhyh1435589631/article/details/54236643

$$\sqrt{(a-b)^2} = \sqrt{a^2 + b^2 - 2ab}$$

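Broadcasting does the rest. The goal is an M*N result, where test_sum has M entries and train_sum has N entries, so only test_sum needs to be turned into an M*1 column (reshape(-1, 1) in the code; a plain .T leaves a 1-D array unchanged). Nothing else has to change, and the output is an M*N matrix.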
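
The shapes involved can be checked on a small example. The sketch below uses made-up random data with M = 4 and N = 5 and compares the broadcast result with a plain two-loop computation. The np.maximum clip is an extra safeguard that is not in the code above; it is there because floating-point rounding can leave tiny negative values under the square root.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.random((4, 6))         # M = 4 test rows (made-up data)
X_train = rng.random((5, 6))   # N = 5 training rows (made-up data)

test_sum = np.sum(np.square(X), axis=1)          # shape (4,)
train_sum = np.sum(np.square(X_train), axis=1)   # shape (5,)
inner_product = np.dot(X, X_train.T)             # shape (4, 5)

# (4, 1) + (5,) broadcasts to (4, 5)
sq = test_sum.reshape(-1, 1) + train_sum - 2 * inner_product
# clip tiny negative values caused by rounding before taking the square root
dists = np.sqrt(np.maximum(sq, 0))

# reference: explicit two-loop Euclidean distance
expected = np.zeros((4, 5))
for i in range(4):
    for j in range(5):
        expected[i, j] = np.sqrt(np.sum((X[i] - X_train[j]) ** 2))

print(dists.shape)                   # (4, 5)
print(np.allclose(dists, expected))  # True
print(test_sum.T.shape)              # (4,): transposing a 1-D array changes nothing
```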

predict_labels

def predict_labels(self, dists, k=1):
           
https://blog.csdn.net/guangtishai4957/article/details/79950117

How y_pred[i] = np.argmax(np.bincount(closest_y)), the second-to-last line of the predict_labels function, works:

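import numpy as np

# how bincount works
x = np.array([0, 1, 1, 3, 3, 3, 3, 5])
# bincount returns an array whose indices correspond to the values in x.
# Here the largest value in x is 5, so the result has indices 0 to 5, and the
# entry at each index is the number of times that value occurs in x.
y = np.bincount(x)
print(y)
# output: [1 2 0 4 0 1]
# np.argmax returns the index of the largest entry of its argument;
# the largest entry of y is 4, at index 3, so the result is 3
# Together, the two functions return the class that occurs most often.
z = np.argmax(y)
print(z)
# output: 3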
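
For context, here is one way that line could sit inside predict_labels. This is a sketch under stated assumptions, not the assignment's reference solution: it is written as a standalone function, takes a hypothetical y_train argument in place of self.y_train, and uses np.argsort to pick the k nearest neighbors. The distance matrix and labels are made up.

```python
import numpy as np

def predict_labels(dists, y_train, k=1):
    # dists: (num_test, num_train) distance matrix
    # y_train: (num_train,) non-negative integer labels (hypothetical argument)
    num_test = dists.shape[0]
    y_pred = np.zeros(num_test, dtype=y_train.dtype)
    for i in range(num_test):
        # indices of the k smallest distances in row i
        nearest = np.argsort(dists[i])[:k]
        closest_y = y_train[nearest]
        # most frequent label; on ties argmax returns the smallest label
        y_pred[i] = np.argmax(np.bincount(closest_y))
    return y_pred

# made-up example: 2 test points, 4 training points
dists = np.array([[0.1, 0.5, 0.2, 0.9],
                  [0.8, 0.3, 0.7, 0.2]])
y_train = np.array([1, 0, 1, 2])
result = predict_labels(dists, y_train, k=3)
print(result)  # [1 0]
```

In the second row the three nearest labels are 2, 0 and 1, one vote each, so the tie goes to the smallest label, 0.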
           
