天天看點

Pytorch:線性代數

線性代數

PyTorch的線性函數主要封裝了Blas和Lapack,其用法和接口都與之類似。常用的線性代數函數如表3-7所示。

表3-7: 常用的線性代數函數

函數 功能
trace 對角線元素之和(矩陣的迹)
diag 對角線元素
triu/tril 矩陣的上三角/下三角,可指定偏移量
mm/bmm 矩陣乘法,batch的矩陣乘法
addmm/addbmm/addmv/addr/badbmm.. 矩陣運算
t 轉置
dot/cross 内積/外積
inverse 求逆矩陣
svd 奇異值分解

具體使用說明請參見官方文檔1,需要注意的是,矩陣的轉置會導緻存儲空間不連續,需調用它的

.contiguous

方法将其轉為連續。

  1. http://pytorch.org/docs/torch.html#blas-and-lapack-operations↩

In [88]:

b = a.t()
b.is_contiguous()
      

Out[88]:

False      

In [89]:

b.contiguous()
      

Out[89]:

tensor([[ 0.,  9.],
        [ 3., 12.],      

繼續閱讀