天天看點

常用寫法小總結1 Numpy,Tensor,CPU,GPU對象之間的互相轉換

1 Numpy,Tensor,CPU,GPU對象之間的互相轉換

1)導入需要的子產品 、

import torch
import numpy as np
from torch.autograd import Variable 
           

2)tensor間的轉換

a = torch.ones(2,3) # 建立全為1的tensor
print("a:",a)
float_a = a.data.float() # 轉為FloatTensor
print("float_a:",float_a)
int_a = a.type(torch.IntTensor) # 使用type()函數轉為指定類型的tensor
print("int_a:",int_a)

# b為DoubleTensor
b = torch.eye(2,3).data.double()
print("b:",b)
# 不知轉換為什麼類型時,可将其轉換為已知某個資料的類型
a_ = a.type_as(b)
print("a_類型:",a_.type())
print("a_:",a_)
           
a: tensor([[ 1.,  1.,  1.],
        [ 1.,  1.,  1.]])
float_a: tensor([[ 1.,  1.,  1.],
        [ 1.,  1.,  1.]])
int_a: tensor([[ 1,  1,  1],
        [ 1,  1,  1]], dtype=torch.int32)
b: tensor([[ 1.,  0.,  0.],
        [ 0.,  1.,  0.]], dtype=torch.float64)
a_類型: torch.DoubleTensor
a_: tensor([[ 1.,  1.,  1.],
        [ 1.,  1.,  1.]], dtype=torch.float64)
           

 3)CPU <-> GPU

print("GPU可用數目:",torch.cuda.device_count())
# CPU張量->GPU
var = torch.Tensor(2,3)
if torch.cuda.is_available():
    var = var.cuda()
print("var:",var)
# GPU張量->CPU
# 直接從cuda中擷取資料,會出錯
#var = var.cuda().data.numpy()
var = var.cuda().data.cpu().numpy()
           
GPU可用數目: 0
           

4)tensor <-> numpy

# tensor和numpy對象共享記憶體,之間轉換很快
# numpy->tensor
a = np.ones((2,3))
a_tensor = torch.from_numpy(a)
print("a:",a)
print("a_tensor:",a_tensor)

# tensor->numpy
b = a_tensor.numpy()
print("b:",b)
           
a: [[1. 1. 1.]
 [1. 1. 1.]]
a_tensor: tensor([[ 1.,  1.,  1.],
        [ 1.,  1.,  1.]], dtype=torch.float64)
b: [[1. 1. 1.]
 [1. 1. 1.]]
           

5)Variable

# Variable簡單封裝了tensor,并支援幾乎所有Tensor
var_tensor = Variable(torch.Tensor(2,3))
print("var_tensor:",var_tensor)
# Variable<->numpy之間的轉換
var_numpy = var_tensor.data.numpy()
var_to_tensor = Variable(torch.from_numpy(var_numpy))
print("var_numpy:",var_numpy)
print("var_to_tensor:",var_to_tensor)
           
var_tensor: tensor(1.00000e-39 *
       [[ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  9.4592,  0.0000]])
var_numpy: [[4.203895e-45 0.000000e+00 1.401298e-45]
 [0.000000e+00 9.459202e-39 0.000000e+00]]
var_to_tensor: tensor(1.00000e-39 *
       [[ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  9.4592,  0.0000]])
           

繼續閱讀