天天看點

PyTorch碎片:深刻透徹了解Torch中Tensor.contiguous()函數1.函數定義2.定義了解3.資料案例分析4.總結

1.函數定義

Returns a contiguous tensor containing the same data as self tensor.

傳回一個與原始tensor相同元素資料的 “連續”tensor類型

If self tensor is contiguous, this function returns the self tensor.

如果原始tensor本身就是連續的,則傳回原始tensor

2.定義了解

定義本身有兩個重要的點:

  • 對原始tensor進行複制
  • 傳回contiguous“類型”的一個tensor

Tensor.contiguous()函數不會對原始資料進行任何修改,而僅僅對其進行複制,并在記憶體空間上進行對齊,即在記憶體空間上,tensor元素的記憶體位址保持連續。

這麼做的目的是,在對tensor元素進行轉換和次元變換等操作之後,元素位址在記憶體空間中保證連續性,在後續利用指針對tensor元素進行讀取時,能夠減少讀取便利,提高記憶體空間優化。

3.資料案例分析

import torch
src_t = torch.randn((2,3))
print(src_t.shape)
print(src_t.is_contiguous())
           

輸出:

>>> torch.Size([2, 3])
>>> True
           

可以看出,在利用torch.randn函數進行tensor建立時,擷取的tensor元素位址是連續記憶體空間儲存的。那麼,如果對建立的tensor進行transpose變換操作:

trans_t = src_t.transpose(0,1) 
print(trans_t.shape)
print(trans_t.is_contiguous())
           

輸出:

>>> torch.Size([3, 2])
>>> False
           

我們發現經過transpose變換以後,tensor變成非連續儲存類型(uncontiguous)。

那麼,變成這種非連續儲存類型會造成什麼樣的後果呢?

簡單的以view函數為例:

當嘗試對uncontiguous類型tensor進行次元變換時,就會出現下面錯誤:

Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
           

錯誤提示告訴我們,至少有一個次元資料在記憶體空間上跨越了兩個連續子空間!此時,我們輸出trans_t的連續儲存類型是什麼:

是以,為了能夠實作對張量trans_t的次元變換,需要先對tensor進行contiguous記憶體位址對齊操作,然後再進行view操作:

print(trans_t.shape)
trans_t.contiguous().view(-1,3)
print(trans_t.shape)
           
>>> torch.Size([3, 2])
>>> torch.Size([2, 3])
           

4.總結

總結一下,為了保證代碼的可讀性和嚴謹性,當對tensor進行次元變化時,常需要配合contiguous函數使用,但是哪些函數會造成原始tensor變的uncontiguous呢?

  • transpose()
  • narrow()
  • expand()

有其他函數,我會進一步補充,有錯誤歡迎指正,謝謝!

繼續閱讀