關心差别的可以直接看[3.不同點]和[4.連續性問題]
前言
在
pytorch
中轉置用的函數就隻有這兩個
-
transpose()
-
permute()
注意隻有
transpose()
有字尾格式:
transpose_()
:字尾函數的作用是簡化如下代碼:
x = x.transpose(0,1)
等價于
x.transpose_()
# 相當于x = x + 1 簡化為 x+=1
這兩個函數功能相同,有一些在記憶體占用上的細微差別,但是不影響程式設計、可以忽略
1. 官方文檔
transpose()
transpose()
torch.transpose(input, dim0, dim1, out=None) → Tensor
函數傳回輸入矩陣
input
的轉置。交換次元
dim0
和
dim1
參數:
- input (Tensor) – 輸入張量,必填
- dim0 (int) – 轉置的第一維,預設0,可選
- dim1 (int) – 轉置的第二維,預設1,可選
permute()
permute()
permute(dims) → Tensor
将tensor的次元換位。
參數:
- dims (int…*)-換位順序,必填
2. 相同點
- 都是傳回轉置後矩陣。
- 都可以操作高緯矩陣,
在高維的功能性更強。permute
3.不同點
先定義我們後面用的資料如下
# 創造二維資料x,dim=0時候2,dim=1時候3
x = torch.randn(2,3) 'x.shape → [2,3]'
# 創造三維資料y,dim=0時候2,dim=1時候3,dim=2時候4
y = torch.randn(2,3,4) 'y.shape → [2,3,4]'
- 合法性不同
torch.transpose(x)
合法,
x.transpose()
合法。
tensor.permute(x)
不合法,
x.permute()
合法。
參考第二點的舉例
- 操作
不同:dim
transpose()
隻能一次操作兩個次元;
permute()
可以一次操作多元資料,且必須傳入所有次元數,因為
permute()
的參數是
int*
。
舉例
# 對于transpose
x.transpose(0,1) 'shape→[3,2] '
x.transpose(1,0) 'shape→[3,2] '
y.transpose(0,1) 'shape→[3,2,4]'
y.transpose(0,2,1) 'error,操作不了多元'
# 對于permute()
x.permute(0,1) 'shape→[2,3]'
x.permute(1,0) 'shape→[3,2], 注意傳回的shape不同于x.transpose(1,0) '
y.permute(0,1) "error 沒有傳入所有次元數"
y.permute(1,0,2) 'shape→[3,2,4]'
-
中的transpose()
沒有數的大小區分;dim
中的permute()
有數的大小區分dim
舉例,注意後面的
shape
:
# 對于transpose,不區分dim大小
x1 = x.transpose(0,1) 'shape→[3,2] '
x2 = x.transpose(1,0) '也變換了,shape→[3,2] '
print(torch.equal(x1,x2))
' True ,value和shape都一樣'
# 對于permute()
x1 = x.permute(0,1) '不同transpose,shape→[2,3] '
x2 = x.permute(1,0) 'shape→[3,2] '
print(torch.equal(x1,x2))
'False,和transpose不同'
y1 = y.permute(0,1,2) '保持不變,shape→[2,3,4] '
y2 = y.permute(1,0,2) 'shape→[3,2,4] '
y3 = y.permute(1,2,0) 'shape→[3,4,2] '
4.關于連續contiguous()
經常有人用
view()
函數改變通過轉置後的資料結構,導緻報錯
RuntimeError: invalid argument 2: view size is not compatible with input tensor's....
這是因為tensor經過轉置後資料的記憶體位址不連續導緻的,也就是
tensor . is_contiguous()==False
這時候
reshape()
可以改變該tensor結構,但是
view()
不可以,具體不同可以看view和reshape的差別
例子如下:
x = torch.rand(3,4)
x = x.transpose(0,1)
print(x.is_contiguous()) # 是否連續
'False'
# 再view會發現報錯
x.view(3,4)
'''報錯
RuntimeError: invalid argument 2: view size is not compatible with input tensor's....
'''
# 但是下面這樣是不會報錯。
x = x.contiguous()
x.view(3,4)
我們再看看
reshape()
x = torch.rand(3,4)
x = x.permute(1,0) # 等價x = x.transpose(0,1)
x.reshape(3,4)
'''這就不報錯了
說明x.reshape(3,4) 這個操作
等于x = x.contiguous().view()
盡管如此,但是torch文檔中還是不推薦使用reshape
理由是除非為了擷取完全不同但是資料相同的克隆體
'''
調用
contiguous()
時,會強制拷貝一份
tensor
,讓它的布局和從頭建立的一毛一樣。
(這一段看文字你肯定不了解,你也可以不用了解,有空我會畫圖補上)
隻需要記住了,每次在使用
view()
之前,該
tensor
隻要使用了
transpose()
和
permute()
這兩個函數一定要
contiguous()
.
5.總結
最重要的差別應該是上面的第三點和第四個。
另外,簡單的資料用
transpose()
就可以了,但是個人覺得不夠直覺,指向性弱了點;複雜次元的可以用
permute()
,對于次元的改變,一般更加精準。