天天看點

pytorch不用for循環計算一個矩陣各行之間的L1 、L2範數距離和餘弦距離

最近需要在pytorch中計算一個K×N維的矩陣各行之間的歐式距離,并輸出距離矩陣(K×K維)。最開始使用的for循環,無奈效率太低,是以需要尋找用矩陣的方法來實作,比較好找的函數是:

from scipy.spatial import distance
scipy.spatial.distance.pdist()
           

以及

from scipy.spatial import distance
scipy.spatial.distance.cdist()
           

具體怎麼用我就不在這裡細說了,請自行搜尋。但是這兩個函數我都不好直接調用,因為我需要用在pytorch架構下自定義的loss函數中使用,如果使用上面的函數,就需要将tensor資料轉換為numpy資料,造成梯度丢失。經過近期的學習發現幾種解決辦法:

1. torch中自帶的pdist()。但是這個函數輸出結果為距離向量,而不是距離矩陣。距離向量是距離矩陣中上三角的元素。

import torch
import torch.tensor as tensor
import torch.nn.functional as F
a = tensor([[1., 1., 1.],
            [2., 2., 2.],
            [3., 3., 3.],
            [4., 4., 4.]])  #建立tensor
 
print(a)
d=F.pdist(a, p=2)
print(d)
           

輸出結果:

tensor([[1., 1., 1.],
        [2., 2., 2.],
        [3., 3., 3.],
        [4., 4., 4.]])
tensor([1.7321, 3.4641, 5.1962, 1.7321, 3.4641, 1.7321])
           

2.  自定義pdist,代碼來自https://blog.csdn.net/LoveCarpenter/article/details/85048291。但是該函數隻能用來計算歐式距離(L2範數),而且對角線上的元素不是0,而是一個極小的數1e-4。

import torch
import torch.tensor as tensor
def pdists(A, squared = False, eps = 1e-8):
	prod = torch.mm(A, A.t())
	norm = prod.diag().unsqueeze(1).expand_as(prod)
	res = (norm + norm.t() - 2 * prod).clamp(min = 0)
	return res if squared else res.clamp(min = eps).sqrt()

#建立一個tensor
a = tensor([[1., 2., 3.],
            [4., 5., 6.],
            [7., 8., 9.],
            [10., 11., 12.]])
print(a)
c=pdists(a, squared = False)
print(c)
           

輸出結果:

tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.],
        [10., 11., 12.]])
tensor([[1.0000e-04, 5.1962e+00, 1.0392e+01, 1.5588e+01],
        [5.1962e+00, 1.0000e-04, 5.1962e+00, 1.0392e+01],
        [1.0392e+01, 5.1962e+00, 1.0000e-04, 5.1962e+00],
        [1.5588e+01, 1.0392e+01, 5.1962e+00, 1.0000e-04]])
           

3. 最終解決辦法。在自行閱讀torch.nn.functional.pdist的文檔介紹時發現了這麼一句話

pytorch不用for循環計算一個矩陣各行之間的L1 、L2範數距離和餘弦距離

簡單翻譯:計算輸入中每​​對行向量之間的p範數距離。 這與torch.norm(input[:, None] - input, dim=2, p=p)的對角線以外的上部三角形部分相同。 如果行是連續的,此功能将更快。

這就是赤果果的提示啊!這個方法的好處是可以根據需要選用L1、L2範數或者其他範數。

import torch
import torch.tensor as tensor
a = tensor([[1., 1., 1.],
            [2., 2., 2.],
            [3., 3., 3.],
            [4., 4., 4.]])  #建立tensor
b=torch.norm(a[:, None]-a, dim=2, p=2)
print(b)
           

輸出結果:

tensor([[0.0000, 1.7321, 3.4641, 5.1962],
        [1.7321, 0.0000, 1.7321, 3.4641],
        [3.4641, 1.7321, 0.0000, 1.7321],
        [5.1962, 3.4641, 1.7321, 0.0000]])
           

-------------------------------

如果想實作餘弦距離怎麼辦呢,其實也可以自己寫個簡單的函數,餘弦距離公式經過推導後可以變成兩個部分相除(以下公式來自網絡):

pytorch不用for循環計算一個矩陣各行之間的L1 、L2範數距離和餘弦距離
pytorch不用for循環計算一個矩陣各行之間的L1 、L2範數距離和餘弦距離

是以自己寫個函數實作一下就好啦:

import torch
def cosinematrix(A):
    prod = torch.mm(A, A.t())#分子
    norm = torch.norm(A,p=2,dim=1).unsqueeze(0)#分母
    cos = prod.div(torch.mm(norm.t(),norm))
    return cos

# 使用
d_matrix=cosinematrix(inputs)
           

繼續閱讀