天天看點

PyTorch中Numpy,Tensor與Variable深入了解與轉換技巧

PyTorch為了實作GPU加速功能,在Numpy的基礎上,引入了Tensor,為了實作自動求導功能,引入了Variable。我們一般讀取的資料都是以Numpy Array方式的。在TensorFlow,Numpy的資料會在輸入網絡後自動轉換為Tensor,一般不需要我們進行顯性操作,當然偶爾也會有例外。但是在PyTorch,需要我們自己進行顯性操作才可以的。

下面我以一個網絡訓練的過程來講解它們之間如何進行互相的轉換。

首先我們會讀取Numpy的資料,為了能夠送入網絡,使用GPU計算加速,是以要進行Numpy2Tensor操作,由于網絡輸入輸出都是Variable,我們還需要Tensor2Variable。在訓練的過程中,我們需要取出loss的值,由于loss參與了backward(),是以此時的loss已經變成了Variable,我們取出loss時需要取出的是Tensor。同樣的,如果我想取出網絡輸出的結果時,由于網絡輸入輸出都是Variable,也需要執行Variable2Tensor,如果進一步我們想把loss顯示出來,就需要Tensor2Numpy。

總結一下,真正和我們開發人員直接接觸的是Numpy資料,需要送入網絡時進行Numpy2Tensor,如果一些Tensor作為參數需要求解梯度資訊時進行Tensor2Variable。需要從Variable取資料時,使用Variable2Tensor。對Tensor進行讀取操作時需要Tensor2Numpy。

轉換方法

Numpy2Tensor:

  1. torch.from_numpy(Numpy_data)
  2. torch.tensor(Numpy_data)

Tensor2Variable:

  1. Variable(Tensor_data)

Variable2Tensor:

  1. Variable_data.data()

Tensor2Numpy :

  1. Tensor_data.numpy()

注意一點,Numpy與Variable無法直接轉換,需要經過Tensor作為中介。