天天看點

pytorch中contiguous()的功能

方法介紹

torch.view()方法對張量改變“形狀”其實并沒有改變張量在記憶體中真正的形狀。

簡單地說,view方法沒有拷貝新的張量,沒有開辟新記憶體,與原張量共享記憶體,隻是重新定義了通路張量的規則,使得取出的張量按照我們希望的形狀展現。

舉例說,如下代碼:

t = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 8], [9, 10, 11, 12]])
t2 = t.transpose(0, 1)
print(t2)
tensor([[ 0,  4,  9],
        [ 1,  5, 10],
        [ 2,  6, 11],
        [ 3,  8, 12]])
t3 = t2.view(2, 6)
print(t3)
# 報錯原因:改變了形狀的t2語義上是4行3列的,在記憶體中還是跟t一樣,沒有改變,導緻如果按照語義的形狀進行view拉伸,數字不連續。
File "E:/.../test3.py", line 109, in <module>
    t3 = t2.view(2, 6)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
           

那怎麼辦呢?此時torch.contiguous()方法就派上用場了。

看如下代碼:

t4 = t.transpose(0, 1)
print(t4)
tensor([[ 0,  4,  9],
        [ 1,  5, 10],
        [ 2,  6, 11],
        [ 3,  8, 12]])
t5 = t4.contiguous()  # 重點說三遍!重點說三遍!重點說三遍!
t6 = t5.view(2, 6)
print(t6)
tensor([[ 0,  4,  9,  1,  5, 10],
        [ 2,  6, 11,  3,  8, 12]])
           

總結

view隻能用在contiguous的variable上。如果在view之前用了transpose, permute等,需要用contiguous()來傳回一個contiguous copy。

還有可能的解釋是:

有些tensor并不是占用一整塊記憶體,而是由不同的資料塊組成,而tensor的view()操作依賴于記憶體是整塊的,這時隻需要執行contiguous()這個函數,把tensor變成在記憶體中連續分布的形式。

判斷是否contiguous用**torch.Tensor.is_contiguous()**函數。

torch.contiguous()方法首先拷貝了一份張量在記憶體中的位址,然後将位址按照形狀改變後的張量的語義進行排列。就是說contiguous()方法改變了多元數組在記憶體中的存儲順序,以便配合view方法使用。

在pytorch的0.4版本中,增加了torch.reshape(), 這與 numpy.reshape 的功能類似。它大緻相當于 tensor.contiguous().view()

繼續閱讀