天天看點

Yolov3模型架構darknet研究(三)darknet推理部分代碼簡要梳理

darknet代碼分析先從Inference(推理)部分講起吧,畢竟隻有forward部分,而且比backward要簡單一些。

首先聲明,我分析的代碼是基于darknentAB版本,非官方版本, 畢竟他倆的代碼還是有稍許不一樣。

最上層的入口函數自然是darknet.c中main(),然後就是run_detector(...) 它是training和Inference功能實作接口。再次重申,本博文講的是後者,即推理部分,其對應的子接口是test_detector(。。。)

好了,正文剛剛開始,下面重點分析test_detector(。。。)的實作代碼。

1)準備部分代碼如下,我在裡面對每一行都添加了具體注釋

看代碼注釋前,講解一下

調用darknet推理接口的形式是:

./darknet detector test cfg/xxx.data cfg/xxx.cfg yolov3.weights data/xxx.jpg

這裡,./darknet是unix的可執行檔案,在windows下就是字尾exe檔案了。調用它就會進入main, 
其後幾個都是參數,存放到main的argv裡面, 下面分别講解:
detector和test參數  是作為flag來保證程式執行進入test_detector()


cfg/xxx.data  是來描述目标類型名字定義在哪個檔案,訓練用的樣本圖檔和label檔案在哪裡,總之,訓練時這個檔案就應該ready
cfg/xxx.cfg  用來描述網絡層次結構以及每一個結構都有哪些成員及參數, 非常重要的配置檔案,訓練時也應該ready

yolov3.weights  這個是訓練好的權值檔案
data/xxx.jpg  待推理的目标圖檔,位址及檔案名由客戶指定
           
//把data config檔案内容讀到options連結清單結構裡面
   list *options = read_data_cfg(datacfg);
    //通過 names這個key,找到其對應的描述class names的data檔案
    char *name_list = option_find_str(options, "names", "data/names.list");
    int names_size = 0;
    //讀取所有class names
    char **names = get_labels_custom(name_list, &names_size); //get_labels(name_list);
    //讀取字母如a、b、c、d等字母所對應的小圖檔,将來會被目标檢測框上的字元顯示所調用
    image **alphabet = load_alphabet();
    //parse 描述該算法模型的網絡結構的config檔案并指派給net變量
    //注意  network是非常關鍵的變量類型,用來描述模型的網絡結構
    network net = parse_network_cfg_custom(cfgfile, 1); // set batch=1
    //根據net網絡結構,把weightfile二進制檔案正确的賦給net結構裡面的每一個成員
    if (weightfile) {
        load_weights(&net, weightfile);
    }
           

2)另外一個小鋪墊如下,即講batch normal的輸入權值進行标準化。這裡面由個疑問是,按定義,BN應該是輸出值進行标準化,而不是權值,此外,官方版本darknet源碼沒有看到這個函數的存在。

fuse_conv_batchnorm(net);
           

 3)開始resize成nework size并計算每一層的輸出值

//忍不住,吐槽一下 image resize是簡單的雙線性插值,為了更好的識别效果,可以考慮采用更好的插值算法。
         image sized = resize_image(im, net.w, net.h);
         。。。 。。。
        //X就是上面的 reized image data
        //net就是network類型網絡結構變量,而且訓練好的權值已經正确的指派給net裡面各個layer的成   
        //員。  然後開始根據輸入和權值開始計算每一層的輸出并存放在l.output變量裡面
        //這個是最耗費時間的函數,往往計算性能優化就在 W.X裡面
        network_predict(net, X);
           

4)計算完後,最後開始找detections,并用threshold來過濾掉不合适的目标檢測。

detection *dets = get_network_boxes(&net, im.w, im.h, thresh, hier_thresh, 0, 1, &nboxes, letterbox);
        if (nms) do_nms_sort(dets, nboxes, l.classes, nms);
        draw_detections_v3(im, dets, nboxes, thresh, names, alphabet, l.classes, ext_output);
           

get_network_boxes調用了兩個重要函數:

//使用3個yolo層19x19 38x38 76x76來分别檢測目标檢測的置信度是否大于0.25
   //如果大于,則記錄下來,最後傳回給dets
    detection *dets = make_network_boxes(net, thresh, num);
    //對檢測出來的dets的每個結構變量進行成員指派,包括機率,類型名稱等
    fill_network_boxes(net, w, h, thresh, hier, map, relative, dets, letter);
           

do_nms_sort()是檢測出來的目标,進行非最大化抑制。即先找出最大機率的目标,然後将iou比較大的其它目标檢測機率置成0. 

最後調用draw_detections_v3()來對每個detection進行檢查,看它對應的哪個類型的機率最高,而且必須超過設定的threshold,才記錄下來作為final detect results。 如果不超過threshold 就被放棄。 上面的接口參數可以加 -thresold來指定,否則調用預設threshold(0.48)。  最後的最後 将這些final detects畫在圖檔上。 

繼續閱讀