天天看點

ssd.pytorch源碼分析(三)— 非極大值抑制NMSNMS介紹相關函數複現代碼

NMS源碼

SSD論文連結

NMS介紹

吳恩達對于NMS(非極大值抑制)的介紹:

ssd.pytorch源碼分析(三)— 非極大值抑制NMSNMS介紹相關函數複現代碼

說白了,NMS的作用就是去掉目标檢測任務重複的檢測框。 例如,一個目标有多個選擇框,現在要去掉多餘的選擇框。怎麼做呢?循環執行步驟1和2, 直到隻剩下一個框:

  • 1、選出置信度p_c最高的框;
  • 2、去掉和這個框IOU>0.7的框。

相關函數

一、torch.clamp( )

torch.clamp(input, min, max, out=None) → Tensor
           

将輸入input張量每個元素夾緊到區間 [min,max],并傳回結果到一個新張量。

類似于numpy中的np.clip

操作定義如下:

| min, if x_i < min
y_i = | x_i, if min <= x_i <= max
      | max, if x_i > max
           

參數:

  • input (Tensor) – 輸入張量
  • min (Number) – 限制範圍下限
  • max (Number) – 限制範圍上限
  • out (Tensor, optional) – 輸出張量

例子:

>>> a = torch.randn(4)
>>> a
 1.3869
 0.3912
-0.8634
-0.5468
[torch.FloatTensor of size 4]
>>> torch.clamp(a, min=-0.5, max=0.5)
 0.5000
 0.3912
-0.5000
-0.5000
[torch.FloatTensor of size 4]
           

二、torch.index_select()

torch.index_select(input, dim, index, out=None) → Tensor
           

沿着指定次元對輸入進行切片。

參數:

  • input (Tensor) – 輸入張量
  • dim (int) – 索引的軸
  • index (LongTensor) – 包含索引下标的一維張量
  • out (Tensor, optional) – 目标張量

例子:

>>> x = torch.randn(3, 4)
>>> x

 1.2045  2.4084  0.4001  1.1372
 0.5596  1.5677  0.6219 -0.7954
 1.3635 -1.2313 -0.5414 -1.8478
[torch.FloatTensor of size 3x4]

>>> indices = torch.LongTensor([0, 2])
>>> torch.index_select(x, 0, indices)

 1.2045  2.4084  0.4001  1.1372
 1.3635 -1.2313 -0.5414 -1.8478
[torch.FloatTensor of size 2x4]

>>> torch.index_select(x, 1, indices)

 1.2045  0.4001
 0.5596  0.6219
 1.3635 -0.5414
[torch.FloatTensor of size 3x2]
           

注意,index_select函數中的參數index表示了有哪些索引值是需要保留的。

三、 torch.numel()

傳回input 張量中的元素個數。

複現代碼

以下為ssd.pytorch中NMS(實際上在任何anchor based的目标檢測架構中都适用)。其中:

  • 為了減少計算量,作者僅選取置信度前top_k=200個框;
  • 代碼中包含了IOU的計算。關于IOU計算推薦閱讀這篇文章;
def nms(boxes, scores, overlap=0.7, top_k=200):
    """
    輸入:
        boxes: 存儲一個圖檔的所有預測框。[num_positive,4].
        scores:置信度。如果為多分類則需要将nms函數套在一個循環内。[num_positive].
        overlap: nms抑制時iou的門檻值.
        top_k: 先選取置信度前top_k個框再進行nms.
    傳回:
        nms後剩餘預測框的索引.
    """
    
    keep = scores.new(scores.size(0)).zero_().long() 
    # 儲存留下來的box的索引 [num_positive]
    # 函數new(): 建構一個有相同資料類型的tensor 
    
	#如果輸入box為空則傳回空Tensor
    if boxes.numel() == 0: 
        return keep
        
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    area = torch.mul(x2 - x1, y2 - y1) #并行化計算所有框的面積
    v, idx = scores.sort(0)  # 升序排序
    idx = idx[-top_k:]  # 前top-k的索引,從小到大
    xx1 = boxes.new()
    yy1 = boxes.new()
    xx2 = boxes.new()
    yy2 = boxes.new()
    w = boxes.new()
    h = boxes.new()

    count = 0
    while idx.numel() > 0:
        i = idx[-1]  # 目前最大score對應的索引
        keep[count] = i #存儲在keep中
        count += 1
        if idx.size(0) == 1: #跳出循環條件:box被篩選完了
            break
        idx = idx[:-1]  # 去掉最後一個
        
        #剩下boxes的資訊存儲在xx,yy中
        torch.index_select(x1, 0, idx, out=xx1)
        torch.index_select(y1, 0, idx, out=yy1)
        torch.index_select(x2, 0, idx, out=xx2)
        torch.index_select(y2, 0, idx, out=yy2)
        
        # 計算目前最大置信框與其他剩餘框的交集,不知道clamp的同學确實容易被誤導
        xx1 = torch.clamp(xx1, min=x1[i])  #max(x1,xx1)
        yy1 = torch.clamp(yy1, min=y1[i])  #max(y1,yy1)
        xx2 = torch.clamp(xx2, max=x2[i])  #min(x2,xx2)
        yy2 = torch.clamp(yy2, max=y2[i])  #min(y2,yy2)
        w.resize_as_(xx2)
        h.resize_as_(yy2)
        w = xx2 - xx1 #w=min(x2,xx2)−max(x1,xx1)
        h = yy2 - yy1 #h=min(y2,yy2)−max(y1,yy1)
        w = torch.clamp(w, min=0.0) #max(w,0)
        h = torch.clamp(h, min=0.0) #max(h,0)
        inter = w*h
        
		#計算目前最大置信框與其他剩餘框的IOU
        # IoU = i / (area(a) + area(b) - i)
        rem_areas = torch.index_select(area, 0, idx)  # 剩餘的框的面積
        union = rem_areas + area[i]- inter #并集
        IoU = inter/union  # 計算iou
        
        # 選出IoU <= overlap的boxes(注意le函數的使用)
        idx = idx[IoU.le(overlap)]
    return keep,          count
    	   #[num_remain], num_remain