天天看點

關于CS231N-Assignment1-KNN中no-loop矩陣乘法代碼的講解

在使用無循環的算法進行計算距離的效率是很高的

可以看到No loop算法使用的時間遠遠小于之前兩種算法

Two loop version took 56.785069 seconds
One loop version took 136.449761 seconds
No loop version took 0.591535 seconds   #很快!           

實作代碼主要為以下這一段:

其中X為500×3072的矩陣(測試矩陣)

X_train為5000×3072的矩陣(訓練矩陣)

dists 為500×5000的矩陣(距離矩陣)

題中的目的就是将X中每一行的像素數值與X_train中每一行的像素數值(3072個)進行距離運算得出歐氏距離(L2)再儲存到dists中

核心公式

test_sum = np.sum(np.square(X), axis=1)  # num_test x 1
train_sum = np.sum(np.square(self.X_train), axis=1)  # num_train x 1
inner_product = np.dot(X, self.X_train.T)  # num_test x num_train
dists = np.sqrt(-2 * inner_product + test_sum.reshape(-1, 1) + train_sum)  # broadcast           

公式講解:

假設現在有三個矩陣:A(X)、B(X_train)、C(dists )

将維數縮小以友善操作,稍微進行推導,就可以得出上面的公式了

推導過程如下:

繼續閱讀