MultiBoxLoss源碼
SSD論文連結
本文代碼涉及很多複雜矩陣索引操作,推薦閱讀。
損失函數總覽
在SSD中,預設框default boxes和真實目标ground truth先進行比對。
比對政策細節
然後根據比對到的一對boxes分别計算分類損失和定位損失。

從上面的描述可以看出,可能有多個default boxes比對到一個ground truth的情況。其中α為權重系數,論文和代碼中取1。
代碼中定義了MultiBoxLoss類,其父類為torch.nn.model。"__ init __"函數如下:
class MultiBoxLoss(nn.Module):
def __init__(self, num_classes, overlap_thresh, prior_for_matching,
bkg_label, neg_mining, neg_pos, neg_overlap, encode_target,
use_gpu=True):
#(cfg['num_classes'], 0.5, True, 0, True, 3, 0.5,False, args.cuda)
super(MultiBoxLoss, self).__init__()
self.use_gpu = use_gpu
self.num_classes = num_classes
self.threshold = overlap_thresh #比對時需要的iou門檻值
self.negpos_ratio = neg_pos #需要訓練的負正樣本比例
self.variance = cfg['variance']
forward函數中,為計算損失函數,需要先對資料進行包括比對、正樣本尋找等的操作:
def forward(self, predictions, targets):
"""forward函數第一部分的内容
輸入:
predictions (tuple): 一個三元素的元組,包含了預測資訊.
loc_data [batch,num_priors,4] 所有預設框預測的offsets.
conf_data [batch,num_priors,num_classes] 所有預測框預測的分類置信度.
priors [num_priors,4] 所有預設框的位置
targets [batch,num_objs,5] (last idx is the label).所有真實目标的資訊
傳回:
loss_l, loss_c:定位損失和分類損失
"""
loc_data, conf_data, priors = predictions
num = loc_data.size(0) #batchsize
num_priors = (priors.size(0))
num_classes = self.num_classes
"""每個default box比對一個gt
具體分析見:ssd.pytorch源碼分析(四)"""
#[batch, num_priors, 4] 比對到的真實目标和預設框之間的offset,是learning target
loc_t = torch.Tensor(num, num_priors, 4)
#[batch, num_priors] 比對後預設框的類别,是learning target
conf_t = torch.LongTensor(num, num_priors)
#對于batch中的每一個圖檔進行比對
for idx in range(num):
truths = targets[idx][:, :-1].data #[num_objs,4]
labels = targets[idx][:, -1].data #[num_objs,1]
defaults = priors.data
match(self.threshold, truths, defaults, self.variance, labels,
loc_t, conf_t, idx)
if self.use_gpu:
loc_t = loc_t.cuda()
conf_t = conf_t.cuda()
loc_t = Variable(loc_t, requires_grad=False)
conf_t = Variable(conf_t, requires_grad=False)
#正樣本查找,等于0為背景 [batch, num_priors]
pos = conf_t > 0
num_pos = pos.sum(dim=1, keepdim=True)
"""接下來的操作為損失函數的計算等"""
定位損失
對于預設框定位,論文還是采取了anchor-based檢測算法中最常用的bounding box回歸,損失函數也采用了和RCNN系列一樣的smooth_l1_loss。
對于下圖公式中的g和l,對應在代碼中分别代表已經encode完成的offset。(g代表一對比對的真實框和預設框的offset,對應代碼中的loc_t,l代表預測框和預設框之間的offset,對應代碼中的loc_p)。
"""forward函數第二部分内容"""
# Localization Loss (Smooth L1)
pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data) #[batch, num_priors, 4] 4層都相同 正例索引
#[batch*num_positive, 4] loc_data儲存了所有預設框的predict offset,loc_p儲存其中的正例
loc_p = loc_data[pos_idx].view(-1, 4)
#[batch*num_positive, 4] loc_t儲存了所有預設框的target offset,loc_t儲存其中的正例
loc_t = loc_t[pos_idx].view(-1, 4)
#計算損失
loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False)
難負樣本挖掘
對預設框與真實對象之間比對後會發現,大部分預設框仍然是背景,正負樣本(所有預設框=正樣本+負樣本)數量差異懸殊。如果将所有預設框拿來訓練,将導緻對負樣本的過拟合。是以隻需要“挖掘”那些分類損失最大的負樣本來訓練,其數量為正樣本的三倍。
注意下面代碼中的loss_c是在難負樣本挖掘中用來給預設框排序的,還不是最終的分類損失loss_class。
"""forward函數第三部分内容"""
# 難負樣本挖掘的依據:loss_c
#[batch*num_priors , num_classes]
batch_conf = conf_data.view(-1, self.num_classes)
#[batch*num_priors] 計算所有預設框的分類損失
loss_c = log_sum_exp(batch_conf) -
batch_conf[torch.arange(0,num*num_priors),conf_t.view(-1, 1)]
# [batch*num_priors] 因為是給負樣本排序的,是以手動給正樣本損失置0
loss_c[pos.view(-1, 1)] = 0
loss_c = loss_c.view(num, -1) #[N,num_priors]
# 難負樣本挖掘 Hard Negative Mining
_, loss_idx = loss_c.sort(1, descending=True)
_, idx_rank = loss_idx.sort(1) #各個框loss的排名,從大到小 [batch,num_priors]
num_pos = pos.long().sum(1, keepdim=True) #[batch,1]
num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1) #[batch,1]
neg = idx_rank < num_neg.expand_as(idx_rank) #得到負樣本 [batch,num_priors]
上面的代碼中涉及了兩次排序的方法。總結了一下:
使用一次sort和兩次sort的差別:
- 一次sort:得到的index是按順序排的索引
- 兩次sort:得到原Tensor的映射,排第幾的數字變為排名
了解了兩次sort,這段代碼就不是問題了。
總結:正樣本為預設框與真實框根據iou比對得到,負樣本為分類loss值排序得到。
分類損失
有了正樣本和負樣本,接下來就可以愉快滴計算分類損失了。
"""forward函數第四部分内容"""
# 首先明确:分類損失包括:n個正樣本損失,3n個負樣本損失
pos_idx = pos.unsqueeze(2).expand_as(conf_data) #[batch,num_priors,num_classes]
neg_idx = neg.unsqueeze(2).expand_as(conf_data) #[batch,num_priors,num_classes]
conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1, self.num_classes)
targets_weighted = conf_t[(pos+neg).gt(0)]
loss_class = F.cross_entropy(conf_p, targets_weighted, size_average=False)
N = float(num_pos.data.sum())
loss_l /= N
loss_class /= N
return loss_l, loss_class