天天看点

pytorch筛选统计

大于统计:

iou是可以的

n_correct= int(sum(iou > 0.5))

但是这个是不对的:

aa= torch.sigmoid(conf_preds[:, :, 0][pos_mask])

n_correct= int(sum(aa> 0.5)

正确的:

  count=int(con_ref[con_ref>0.8].size(0))

这个也是正确的:

import torch
idx=torch.Tensor(([0.1,0.2,0.5,0.6],[0.05,0.3,0.7,0.8]))


pos = idx > 1.2

print(idx[pos])
num_pos = int(pos.sum())

print(num_pos)

num_pos = pos.sum(1, keepdim=True)

print(num_pos)

num_pos = pos.long().sum(1, keepdim=True)#正样本第2维度个数统计
print(num_pos)