天天看點

Expected object of scalar type Long but got scalar type Double for argument #2 ‘target‘

1.pytorch報錯:

loss_class = torch.nn.CrossEntropyLoss()
s_data, s_label = data_source[0].to(DEVICE), data_source[1].to(DEVICE)
class_output, domain_output = model(input_data=s_data.float(), alpha=alpha)
# 報錯位置如下:
err_s_label = loss_class(class_output, s_label)      

報錯内容如下:

Expected object of scalar type Long but got scalar type Double for argument #2 'target'

表示第二個位置的參數要求是Long類型,然而傳入的時候是Double類型,是以我們隻需:

s_label.long()      

即可。

2. 如果會繼續出現報錯:

RuntimeError: multi-target not supported at /opt/conda/conda-bld/pytorch_1556653099582/work/aten/src/THCUNN/generic/ClassNLLCriterion.cu:15

表示在計算loss過程中遇到了多輸出預測值。或者标簽的次元是不同的(我這邊标簽的shape是(128, 1)),我們隻需要将标簽squeeze就行,具體參考torch.squeeze()函數,個人變動方法:

err_s_label = loss_class(class_output, s_label.squeeze(1).long())      

3.總結:torch中資料類型的變化:

torch資料類型轉換

資料類型 資料長度 用法
int int32 torch.int()
long int64 torch.long()
float float32 torch.float()
double