天天看點

pytorch篩選統計

大于統計:

iou是可以的

n_correct= int(sum(iou > 0.5))

但是這個是不對的:

aa= torch.sigmoid(conf_preds[:, :, 0][pos_mask])

n_correct= int(sum(aa> 0.5)

正确的:

  count=int(con_ref[con_ref>0.8].size(0))

這個也是正确的:

import torch
idx=torch.Tensor(([0.1,0.2,0.5,0.6],[0.05,0.3,0.7,0.8]))


pos = idx > 1.2

print(idx[pos])
num_pos = int(pos.sum())

print(num_pos)

num_pos = pos.sum(1, keepdim=True)

print(num_pos)

num_pos = pos.long().sum(1, keepdim=True)#正樣本第2次元個數統計
print(num_pos)