最近需要在pytorch中計算一個K×N維的矩陣各行之間的歐式距離,并輸出距離矩陣(K×K維)。最開始使用的for循環,無奈效率太低,是以需要尋找用矩陣的方法來實作,比較好找的函數是:
from scipy.spatial import distance
scipy.spatial.distance.pdist()
以及
from scipy.spatial import distance
scipy.spatial.distance.cdist()
具體怎麼用我就不在這裡細說了,請自行搜尋。但是這兩個函數我都不好直接調用,因為我需要用在pytorch架構下自定義的loss函數中使用,如果使用上面的函數,就需要将tensor資料轉換為numpy資料,造成梯度丢失。經過近期的學習發現幾種解決辦法:
1. torch中自帶的pdist()。但是這個函數輸出結果為距離向量,而不是距離矩陣。距離向量是距離矩陣中上三角的元素。
import torch
import torch.tensor as tensor
import torch.nn.functional as F
a = tensor([[1., 1., 1.],
[2., 2., 2.],
[3., 3., 3.],
[4., 4., 4.]]) #建立tensor
print(a)
d=F.pdist(a, p=2)
print(d)
輸出結果:
tensor([[1., 1., 1.],
[2., 2., 2.],
[3., 3., 3.],
[4., 4., 4.]])
tensor([1.7321, 3.4641, 5.1962, 1.7321, 3.4641, 1.7321])
2. 自定義pdist,代碼來自https://blog.csdn.net/LoveCarpenter/article/details/85048291。但是該函數隻能用來計算歐式距離(L2範數),而且對角線上的元素不是0,而是一個極小的數1e-4。
import torch
import torch.tensor as tensor
def pdists(A, squared = False, eps = 1e-8):
prod = torch.mm(A, A.t())
norm = prod.diag().unsqueeze(1).expand_as(prod)
res = (norm + norm.t() - 2 * prod).clamp(min = 0)
return res if squared else res.clamp(min = eps).sqrt()
#建立一個tensor
a = tensor([[1., 2., 3.],
[4., 5., 6.],
[7., 8., 9.],
[10., 11., 12.]])
print(a)
c=pdists(a, squared = False)
print(c)
輸出結果:
tensor([[ 1., 2., 3.],
[ 4., 5., 6.],
[ 7., 8., 9.],
[10., 11., 12.]])
tensor([[1.0000e-04, 5.1962e+00, 1.0392e+01, 1.5588e+01],
[5.1962e+00, 1.0000e-04, 5.1962e+00, 1.0392e+01],
[1.0392e+01, 5.1962e+00, 1.0000e-04, 5.1962e+00],
[1.5588e+01, 1.0392e+01, 5.1962e+00, 1.0000e-04]])
3. 最終解決辦法。在自行閱讀torch.nn.functional.pdist的文檔介紹時發現了這麼一句話
簡單翻譯:計算輸入中每對行向量之間的p範數距離。 這與torch.norm(input[:, None] - input, dim=2, p=p)的對角線以外的上部三角形部分相同。 如果行是連續的,此功能将更快。
這就是赤果果的提示啊!這個方法的好處是可以根據需要選用L1、L2範數或者其他範數。
import torch
import torch.tensor as tensor
a = tensor([[1., 1., 1.],
[2., 2., 2.],
[3., 3., 3.],
[4., 4., 4.]]) #建立tensor
b=torch.norm(a[:, None]-a, dim=2, p=2)
print(b)
輸出結果:
tensor([[0.0000, 1.7321, 3.4641, 5.1962],
[1.7321, 0.0000, 1.7321, 3.4641],
[3.4641, 1.7321, 0.0000, 1.7321],
[5.1962, 3.4641, 1.7321, 0.0000]])
-------------------------------
如果想實作餘弦距離怎麼辦呢,其實也可以自己寫個簡單的函數,餘弦距離公式經過推導後可以變成兩個部分相除(以下公式來自網絡):
是以自己寫個函數實作一下就好啦:
import torch
def cosinematrix(A):
prod = torch.mm(A, A.t())#分子
norm = torch.norm(A,p=2,dim=1).unsqueeze(0)#分母
cos = prod.div(torch.mm(norm.t(),norm))
return cos
# 使用
d_matrix=cosinematrix(inputs)