天天看點

PyTorch 兩大轉置函數 transpose() 和 permute(), 以及RuntimeError: invalid argument 2: view size is not compati

關心差别的可以直接看[3.不同點]和[4.連續性問題]

前言

pytorch

中轉置用的函數就隻有這兩個

  1. transpose()

  2. permute()

注意隻有

transpose()

有字尾格式:

transpose_()

:字尾函數的作用是簡化如下代碼:

x = x.transpose(0,1)
等價于
x.transpose_()
# 相當于x = x + 1 簡化為 x+=1
           

這兩個函數功能相同,有一些在記憶體占用上的細微差別,但是不影響程式設計、可以忽略

1. 官方文檔

transpose()

torch.transpose(input, dim0, dim1, out=None) → Tensor
           

函數傳回輸入矩陣

input

的轉置。交換次元

dim0

dim1

參數:

  • input (Tensor) – 輸入張量,必填
  • dim0 (int) – 轉置的第一維,預設0,可選
  • dim1 (int) – 轉置的第二維,預設1,可選

permute()

permute(dims) → Tensor
           

将tensor的次元換位。

參數:

  • dims (int…*)-換位順序,必填

2. 相同點

  1. 都是傳回轉置後矩陣。
  2. 都可以操作高緯矩陣,

    permute

    在高維的功能性更強。

3.不同點

先定義我們後面用的資料如下

# 創造二維資料x,dim=0時候2,dim=1時候3
x = torch.randn(2,3)       'x.shape  →  [2,3]'
# 創造三維資料y,dim=0時候2,dim=1時候3,dim=2時候4
y = torch.randn(2,3,4)   'y.shape  →  [2,3,4]'
           
  1. 合法性不同

torch.transpose(x)

合法,

x.transpose()

合法。

tensor.permute(x)

不合法,

x.permute()

合法。

參考第二點的舉例

  1. 操作

    dim

    不同:

transpose()

隻能一次操作兩個次元;

permute()

可以一次操作多元資料,且必須傳入所有次元數,因為

permute()

的參數是

int*

舉例

# 對于transpose
x.transpose(0,1)     'shape→[3,2] '  
x.transpose(1,0)     'shape→[3,2] '  
y.transpose(0,1)     'shape→[3,2,4]' 
y.transpose(0,2,1)  'error,操作不了多元'

# 對于permute()
x.permute(0,1)     'shape→[2,3]'
x.permute(1,0)     'shape→[3,2], 注意傳回的shape不同于x.transpose(1,0) '
y.permute(0,1)     "error 沒有傳入所有次元數"
y.permute(1,0,2)  'shape→[3,2,4]'
           
  1. transpose()

    中的

    dim

    沒有數的大小區分;

    permute()

    中的

    dim

    有數的大小區分

舉例,注意後面的

shape

# 對于transpose,不區分dim大小
x1 = x.transpose(0,1)   'shape→[3,2] '  
x2 = x.transpose(1,0)   '也變換了,shape→[3,2] '  
print(torch.equal(x1,x2))
' True ,value和shape都一樣'

# 對于permute()
x1 = x.permute(0,1)     '不同transpose,shape→[2,3] '  
x2 = x.permute(1,0)     'shape→[3,2] '  
print(torch.equal(x1,x2))
'False,和transpose不同'

y1 = y.permute(0,1,2)     '保持不變,shape→[2,3,4] '  
y2 = y.permute(1,0,2)     'shape→[3,2,4] '  
y3 = y.permute(1,2,0)     'shape→[3,4,2] '  
           

4.關于連續contiguous()

經常有人用

view()

函數改變通過轉置後的資料結構,導緻報錯

RuntimeError: invalid argument 2: view size is not compatible with input tensor's....

這是因為tensor經過轉置後資料的記憶體位址不連續導緻的,也就是

tensor . is_contiguous()==False

這時候

reshape()

可以改變該tensor結構,但是

view()

不可以,具體不同可以看view和reshape的差別

例子如下:

x = torch.rand(3,4)
x = x.transpose(0,1)
print(x.is_contiguous()) # 是否連續
'False'
# 再view會發現報錯
x.view(3,4)
'''報錯
RuntimeError: invalid argument 2: view size is not compatible with input tensor's....
'''

# 但是下面這樣是不會報錯。
x = x.contiguous()
x.view(3,4)
           

我們再看看

reshape()

x = torch.rand(3,4)
x = x.permute(1,0) # 等價x = x.transpose(0,1)
x.reshape(3,4)
'''這就不報錯了
說明x.reshape(3,4) 這個操作
等于x = x.contiguous().view()
盡管如此,但是torch文檔中還是不推薦使用reshape
理由是除非為了擷取完全不同但是資料相同的克隆體
'''
           

調用

contiguous()

時,會強制拷貝一份

tensor

,讓它的布局和從頭建立的一毛一樣。

(這一段看文字你肯定不了解,你也可以不用了解,有空我會畫圖補上)

隻需要記住了,每次在使用

view()

之前,該

tensor

隻要使用了

transpose()

permute()

這兩個函數一定要

contiguous()

.

5.總結

最重要的差別應該是上面的第三點和第四個。

另外,簡單的資料用

transpose()

就可以了,但是個人覺得不夠直覺,指向性弱了點;複雜次元的可以用

permute()

,對于次元的改變,一般更加精準。