天天看點

Pytorch中的contiguous了解

最近遇到這個函數,但查的中文部落格裡的解釋貌似不是很到位,這裡翻譯一下stackoverflow上的回答并加上自己的了解。

在pytorch中,隻有很少幾個操作是不改變tensor的内容本身,而隻是重新定義下标與元素的對應關系的。換句話說,這種操作不進行資料拷貝和資料的改變,變的是中繼資料。

這些操作是:

narrow(),view(),expand()和transpose()

舉個栗子,在使用

transpose()

進行轉置操作時,pytorch并不會建立新的、轉置後的tensor,而是修改了tensor中的一些屬性(也就是中繼資料),使得此時的offset和stride是與轉置tensor相對應的。轉置的tensor和原tensor的記憶體是共享的!

為了證明這一點,我們來看下面的代碼:

x = torch.randn(, )
y = x.transpose(x, , )
x[, ] = 
print(y[, ])
# print 233
           

可以看到,改變了y的元素的值的同時,x的元素的值也發生了變化。

也就是說,經過上述操作後得到的tensor,它内部資料的布局方式和從頭開始建立一個這樣的正常的tensor的布局方式是不一樣的!于是…這就有

contiguous()

的用武之地了。

在上面的例子中,x是contiguous的,但y不是(因為内部資料不是通常的布局方式)。注意不要被contiguous的字面意思“連續的”誤解,tensor中資料還是在記憶體中一塊區域裡,隻是布局的問題!

當調用

contiguous()

時,會強制拷貝一份tensor,讓它的布局和從頭建立的一毛一樣。

一般來說這一點不用太擔心,如果你沒在需要調用

contiguous()

的地方調用

contiguous()

,運作時會提示你:

RuntimeError: input is not contiguous

隻要看到這個錯誤提示,加上

contiguous()

就好啦~

繼續閱讀