天天看點

ssd.pytorch源碼分析(五)—損失函數及Hard negative mining損失函數總覽定位損失難負樣本挖掘分類損失

MultiBoxLoss源碼

SSD論文連結

本文代碼涉及很多複雜矩陣索引操作,推薦閱讀。

損失函數總覽

在SSD中,預設框default boxes和真實目标ground truth先進行比對。

比對政策細節

然後根據比對到的一對boxes分别計算分類損失和定位損失。

ssd.pytorch源碼分析(五)—損失函數及Hard negative mining損失函數總覽定位損失難負樣本挖掘分類損失

從上面的描述可以看出,可能有多個default boxes比對到一個ground truth的情況。其中α為權重系數,論文和代碼中取1。

代碼中定義了MultiBoxLoss類,其父類為torch.nn.model。"__ init __"函數如下:

class MultiBoxLoss(nn.Module):
    def __init__(self, num_classes, overlap_thresh, prior_for_matching,
                 bkg_label, neg_mining, neg_pos, neg_overlap, encode_target,
                 use_gpu=True):
        #(cfg['num_classes'], 0.5, True, 0, True, 3, 0.5,False, args.cuda)
        super(MultiBoxLoss, self).__init__()
        self.use_gpu = use_gpu
        self.num_classes = num_classes 
        self.threshold = overlap_thresh #比對時需要的iou門檻值
        self.negpos_ratio = neg_pos #需要訓練的負正樣本比例
        self.variance = cfg['variance']
           

forward函數中,為計算損失函數,需要先對資料進行包括比對、正樣本尋找等的操作:

def forward(self, predictions, targets):
    """forward函數第一部分的内容
    輸入:
        predictions (tuple): 一個三元素的元組,包含了預測資訊.
            loc_data [batch,num_priors,4] 所有預設框預測的offsets.
            conf_data [batch,num_priors,num_classes] 所有預測框預測的分類置信度.
            priors [num_priors,4] 所有預設框的位置

        targets [batch,num_objs,5] (last idx is the label).所有真實目标的資訊
        		
    傳回:
    	loss_l, loss_c:定位損失和分類損失
    """
    loc_data, conf_data, priors = predictions
    num = loc_data.size(0) #batchsize
    num_priors = (priors.size(0))
    num_classes = self.num_classes

    """每個default box比對一個gt
    具體分析見:ssd.pytorch源碼分析(四)"""
    #[batch, num_priors, 4] 比對到的真實目标和預設框之間的offset,是learning target
    loc_t = torch.Tensor(num, num_priors, 4)
    #[batch, num_priors] 比對後預設框的類别,是learning target
    conf_t = torch.LongTensor(num, num_priors) 
    
    #對于batch中的每一個圖檔進行比對
    for idx in range(num):
        truths = targets[idx][:, :-1].data #[num_objs,4]
        labels = targets[idx][:, -1].data  #[num_objs,1]
        defaults = priors.data
        match(self.threshold, truths, defaults, self.variance, labels,
              loc_t, conf_t, idx) 
              
    if self.use_gpu:
        loc_t = loc_t.cuda()
        conf_t = conf_t.cuda()
    loc_t = Variable(loc_t, requires_grad=False)
    conf_t = Variable(conf_t, requires_grad=False)
    
	#正樣本查找,等于0為背景 [batch, num_priors]
    pos = conf_t > 0
    num_pos = pos.sum(dim=1, keepdim=True)
    
    """接下來的操作為損失函數的計算等"""
           

定位損失

對于預設框定位,論文還是采取了anchor-based檢測算法中最常用的bounding box回歸,損失函數也采用了和RCNN系列一樣的smooth_l1_loss。

對于下圖公式中的g和l,對應在代碼中分别代表已經encode完成的offset。(g代表一對比對的真實框和預設框的offset,對應代碼中的loc_t,l代表預測框和預設框之間的offset,對應代碼中的loc_p)。

ssd.pytorch源碼分析(五)—損失函數及Hard negative mining損失函數總覽定位損失難負樣本挖掘分類損失
"""forward函數第二部分内容"""
	# Localization Loss (Smooth L1)
	pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)  #[batch, num_priors, 4] 4層都相同 正例索引
	
	#[batch*num_positive, 4] loc_data儲存了所有預設框的predict offset,loc_p儲存其中的正例
	loc_p = loc_data[pos_idx].view(-1, 4)
	#[batch*num_positive, 4] loc_t儲存了所有預設框的target offset,loc_t儲存其中的正例
	loc_t = loc_t[pos_idx].view(-1, 4)
	#計算損失
	loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False)
           

難負樣本挖掘

ssd.pytorch源碼分析(五)—損失函數及Hard negative mining損失函數總覽定位損失難負樣本挖掘分類損失

對預設框與真實對象之間比對後會發現,大部分預設框仍然是背景,正負樣本(所有預設框=正樣本+負樣本)數量差異懸殊。如果将所有預設框拿來訓練,将導緻對負樣本的過拟合。是以隻需要“挖掘”那些分類損失最大的負樣本來訓練,其數量為正樣本的三倍。

注意下面代碼中的loss_c是在難負樣本挖掘中用來給預設框排序的,還不是最終的分類損失loss_class。

"""forward函數第三部分内容"""
	# 難負樣本挖掘的依據:loss_c
	#[batch*num_priors , num_classes]
    batch_conf = conf_data.view(-1, self.num_classes)
    #[batch*num_priors] 計算所有預設框的分類損失
    loss_c = log_sum_exp(batch_conf) - 
    	batch_conf[torch.arange(0,num*num_priors),conf_t.view(-1, 1)] 
    # [batch*num_priors] 因為是給負樣本排序的,是以手動給正樣本損失置0
    loss_c[pos.view(-1, 1)] = 0 
    loss_c = loss_c.view(num, -1) #[N,num_priors]

    # 難負樣本挖掘 Hard Negative Mining
    _, loss_idx = loss_c.sort(1, descending=True)
    _, idx_rank = loss_idx.sort(1) #各個框loss的排名,從大到小 [batch,num_priors]
    num_pos = pos.long().sum(1, keepdim=True) #[batch,1]
    num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)  #[batch,1]
    neg = idx_rank < num_neg.expand_as(idx_rank) #得到負樣本 [batch,num_priors]
           

上面的代碼中涉及了兩次排序的方法。總結了一下:

使用一次sort和兩次sort的差別:

  • 一次sort:得到的index是按順序排的索引
  • 兩次sort:得到原Tensor的映射,排第幾的數字變為排名

了解了兩次sort,這段代碼就不是問題了。

總結:正樣本為預設框與真實框根據iou比對得到,負樣本為分類loss值排序得到。

分類損失

有了正樣本和負樣本,接下來就可以愉快滴計算分類損失了。

"""forward函數第四部分内容"""
	# 首先明确:分類損失包括:n個正樣本損失,3n個負樣本損失
    pos_idx = pos.unsqueeze(2).expand_as(conf_data) #[batch,num_priors,num_classes]
    neg_idx = neg.unsqueeze(2).expand_as(conf_data) #[batch,num_priors,num_classes]
    conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1, self.num_classes)
    targets_weighted = conf_t[(pos+neg).gt(0)]
    loss_class = F.cross_entropy(conf_p, targets_weighted, size_average=False)
    

    N = float(num_pos.data.sum())
    loss_l /= N
    loss_class /= N
    return loss_l, loss_class