天天看點

py-faster-rcnn源碼解析之處理訓練資料

因為最近在使用py-faster-rcnn訓練自己的資料時報如下錯:

roidb[i]['image'] = imdb.image_path_at(i) 
IndexError: list index out of range 
           

看了網上的很多說法都是讓删除py-faster-rcnn/data/cache下的pkl檔案,但是該方法對我并沒有起作用,于是就将py-faster-rcnn處理訓練資料部分的代碼跟蹤了一下,這裡和大家一起分享,也做個記錄。

下面的解說都是以py-faster-rcnn目錄為根目錄,後面就不再重複了。

我是通過執行scripts/faster_rcnn_alt_opt.sh來訓練模型的,從該腳本的第46行代碼:

time ./tools/train_faster_rcnn_alt_opt.py --gpu ${GPU_ID} \
           

我們可以知道模型是通過tools/train_faster_rcnn_alt_opt.py進行訓練的,接下裡我們就去看這個py檔案的源碼。

模型訓練分為兩個Stage,每個Stage都是從RPN訓練開始的,是以我們直接看train_rpn函數:

def train_rpn(queue=None, imdb_name=None, init_model=None, solver=None,
              max_iters=None, cfg=None):
    """Train a Region Proposal Network in a separate training process.
    """

    # Not using any proposals, just ground-truth boxes
    cfg.TRAIN.HAS_RPN = True
    cfg.TRAIN.BBOX_REG = False  # applies only to Fast R-CNN bbox regression
    cfg.TRAIN.PROPOSAL_METHOD = 'gt'
    cfg.TRAIN.IMS_PER_BATCH = 1
    print 'Init model: {}'.format(init_model)
    print('Using config:')
    pprint.pprint(cfg)

    import caffe
    _init_caffe(cfg)

    roidb, imdb = get_roidb(imdb_name)
    print 'roidb len: {}'.format(len(roidb))
    output_dir = get_output_dir(imdb)
    print 'Output will be saved to `{:s}`'.format(output_dir)

    model_paths = train_net(solver, roidb, output_dir,
                            pretrained_model=init_model,
                            max_iters=max_iters)
    # Cleanup all but the final model
    for i in model_paths[:-1]:
        os.remove(i)
    rpn_model_path = model_paths[-1]
    # Send final model path through the multiprocessing queue
    queue.put({'model_path': rpn_model_path})
           

前面的幾行代碼是進行訓練的配置,一直到這幾行代碼開始準備資料:

roidb, imdb = get_roidb(imdb_name)
    print 'roidb len: {}'.format(len(roidb))
    output_dir = get_output_dir(imdb)
    print 'Output will be saved to `{:s}`'.format(output_dir)
           

是以我們再跳去get_roidb函數去看它是如何實作的:

def get_roidb(imdb_name, rpn_file=None):
    imdb = get_imdb(imdb_name)
    print 'Loaded dataset `{:s}` for training'.format(imdb.name)
    imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
    print 'Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD)
    if rpn_file is not None:
        imdb.config['rpn_file'] = rpn_file
    roidb = get_training_roidb(imdb)
    return roidb, imdb
           

Stage1 RPN, init from ImageNet model

時輸入參數imdb_name是voc_2007_trainval,rpn_file是None。從這個函數我們能夠得到的資訊是roidb是與imdb相關的,下面我們先看imdb是怎麼得到的,即先看get_imdb函數,這個函數的代碼在lib/datasets/factory.py中:

def get_imdb(name):
    """Get an imdb (image database) by name."""
    if not __sets.has_key(name):
        raise KeyError('Unknown dataset: {}'.format(name))
    return __sets[name]()
           

檔案的開頭處代碼

__sets={}

将__sets定義為字典,而由get_imdb函數的傳回值我們可以看到該字典的key是imdb名稱,而value是個匿名函數,因為我們使用的voc_2007_trainval的資料,是以往上看這段代碼:

# Set up voc_<year>_<split> using selective search "fast" mode
for year in ['2007', '2012']:
    for split in ['train', 'val', 'trainval', 'test']:
        name = 'voc_{}_{}'.format(year, split)
        __sets[name] = (lambda split=split, year=year: pascal_voc(split, year))
           

可以看到上面提到的匿名函數實際上是調用了pascal_voc(split, year),然後我們進入lib/datasets/pascal_voc.py,發現它實際上是pascal_voc的構造函數,這個類繼承了imdb類。

class pascal_voc(imdb):
    def __init__(self, image_set, year, devkit_path=None):
        imdb.__init__(self, 'voc_' + year + '_' + image_set)
        self._year = year
        self._image_set = image_set
        self._devkit_path = self._get_default_path() if devkit_path is None \
                            else devkit_path
        self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)
        self._classes = ('__background__', # always index 0
                         'aeroplane', 'bicycle', 'bird', 'boat',
                         'bottle', 'bus', 'car', 'cat', 'chair',
                         'cow', 'diningtable', 'dog', 'horse',
                         'motorbike', 'person', 'pottedplant',
                         'sheep', 'sofa', 'train', 'tvmonitor')
        self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))
        self._image_ext = '.jpg'
        self._image_index = self._load_image_set_index()
        # Default to roidb handler
        self._roidb_handler = self.selective_search_roidb
        self._salt = str(uuid.uuid4())
        self._comp_id = 'comp4'

        # PASCAL specific config options
        self.config = {'cleanup'     : True,
                       'use_salt'    : True,
                       'use_diff'    : False,
                       'matlab_eval' : False,
                       'rpn_file'    : None,
                       'min_size'    : 2}

        assert os.path.exists(self._devkit_path), \
                'VOCdevkit path does not exist: {}'.format(self._devkit_path)
        assert os.path.exists(self._data_path), \
                'Path does not exist: {}'.format(self._data_path)
           

這個代碼大部分隻是進行了一些配置,但有兩行代碼需要注意。

一行是

self._image_index = self._load_image_set_index

,我們把它列印出來,可以看到它裡面是圖檔的名稱,但去除了字尾和路徑的其他部分,說明它是後面要用來加載圖檔和xml檔案的。

還有一行是

self._roidb_handler = self.selective_search_roidb

,跳轉到selective_search_roidb方法,然後一路跟蹤,先是它調用了gt_roidb方法,然後它調用了_load_pascal_annotation方法,在這個方法裡的第一句代碼就加載了xml檔案(xml檔案存放了需要檢測的目标所在的區域資訊)。

這個函數最後得到了一個鍵為

boxes

gt_classes

gt_overlaps

flipped

seg_areas

的字典,後面我們還會提到。

對于pascal_voc類的分析就到這裡,下面我們看一下pascal_voc的父類imdb,它的實作在lib/dataset/imdb.py中:

class imdb(object):
    """Image database."""

    def __init__(self, name):
        self._name = name
        self._num_classes = 0
        self._classes = []
        self._image_index = []
        self._obj_proposer = 'selective_search'
        self._roidb = None
        self._roidb_handler = self.default_roidb
        # Use this dict for storing dataset specific config options
        self.config = {}
           

這裡,我們能得到的最主要的資訊是roidb是imdb的成員變量。而roidb裡包含哪些資訊其實在下面的代碼中也有說明:

@property
    def roidb(self):
        # A roidb is a list of dictionaries, each with the following keys:
        #   boxes
        #   gt_overlaps
        #   gt_classes
        #   flipped
        if self._roidb is not None:
            return self._roidb
        self._roidb = self.roidb_handler()
        return self._roidb
           

可以看到,roidb是個字典清單,每個字典包含5個鍵,boxes表示目标框,gt_overlaps是重疊資訊,gt_classes是每個box的class資訊,flipped表示該資料是否是翻轉(實際是水準鏡像)得來的,事實上還有一個鍵是seg_areas,這個我們在之前提到過,注釋部分應該是作者手誤少寫了一個。

現在我們回到tools/train_faster_rcnn_alt_opt.py檔案的get_roidb函數中去:

def get_roidb(imdb_name, rpn_file=None):
    imdb = get_imdb(imdb_name)
    print 'Loaded dataset `{:s}` for training'.format(imdb.name)
    imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
    print 'Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD)
    if rpn_file is not None:
        imdb.config['rpn_file'] = rpn_file
    roidb = get_training_roidb(imdb)
    return roidb, imdb
           

通過前面的分析我們知道imdb變量實際上是pascal_voc類的執行個體,而pascal_voc又是imdb類的子類,是以這裡的imdb變量既可以使用pascal_voc類的方法變量又可以使用imdb類的方法變量,這是個需要注意的細節。

然後讓我們從get_imdb函數之後繼續,imdb.set_proposal_method執行後輸出資訊為

Set proposal method: gt

,也就是設定了imdb類的_roidb_handler成員變量,這個可以暫時不用管。然後因為rpn_file為None,後面的if語句自然就跳過了,是以我們直接看get_training_roidb函數,這個函數的代碼在lib/fast_rcnn/train.py檔案中:

def get_training_roidb(imdb):
    """Returns a roidb (Region of Interest database) for use in training."""
    if cfg.TRAIN.USE_FLIPPED:
        print 'Appending horizontally-flipped training examples...'
        imdb.append_flipped_images()
        print 'done'

    print 'Preparing training data...'
    rdl_roidb.prepare_roidb(imdb)
    print 'done'

    return imdb.roidb
           

從函數注釋我們知道了roidb實際上是Region of Interest database的縮寫,而imdb是Image database的縮寫(get_imdb函數的注釋中有說),可以猜測roidb才是實際訓練的重點,這個我們後面也會提到。現在我們看一下該函數的具體實作,if語句中對資料做了水準鏡像,調用了append_flipped_images函數,它的代碼在lib/datasets/imdb.py檔案中:

def append_flipped_images(self):
        num_images = self.num_images
        widths = self._get_widths()
        for i in xrange(num_images):
            boxes = self.roidb[i]['boxes'].copy()
            oldx1 = boxes[:, 0].copy()
            oldx2 = boxes[:, 2].copy()
            boxes[:, 0] = widths[i] - oldx2 - 1
            boxes[:, 2] = widths[i] - oldx1 - 1
            assert (boxes[:, 2] >= boxes[:, 0]).all()
            entry = {'boxes' : boxes,
                     'gt_overlaps' : self.roidb[i]['gt_overlaps'],
                     'gt_classes' : self.roidb[i]['gt_classes'],
                     'flipped' : True}
            self.roidb.append(entry)
        self._image_index = self._image_index * 2
           

可以看到這個函數是把目标框的位置做了水準翻轉,并将翻轉過來的資料的flipped成員設定為True,這樣的一種做法應該主要是為了增大資料量,防止過拟合,提升模型泛化性用的。這裡還沒有實際處理圖檔,不要着急,我們回到get_training_roidb函數中繼續看代碼。現在到了prepare_roidb函數,它的實作在lib/roi_data_layer/roidb.py中:

def prepare_roidb(imdb):
    """Enrich the imdb's roidb by adding some derived quantities that
    are useful for training. This function precomputes the maximum
    overlap, taken over ground-truth boxes, between each ROI and
    each ground-truth box. The class with maximum overlap is also
    recorded.
    """
    sizes = [PIL.Image.open(imdb.image_path_at(i)).size
             for i in xrange(imdb.num_images)]
    roidb = imdb.roidb
    for i in xrange(len(imdb.image_index)):
        roidb[i]['image'] = imdb.image_path_at(i)
        roidb[i]['width'] = sizes[i][0]
        roidb[i]['height'] = sizes[i][1]
        # need gt_overlaps as a dense array for argmax
        gt_overlaps = roidb[i]['gt_overlaps'].toarray()
        # max overlap with gt over classes (columns)
        max_overlaps = gt_overlaps.max(axis=1)
        # gt class that had the max overlap
        max_classes = gt_overlaps.argmax(axis=1)
        roidb[i]['max_classes'] = max_classes
        roidb[i]['max_overlaps'] = max_overlaps
        # sanity checks
        # max overlap of 0 => class should be zero (background)
        zero_inds = np.where(max_overlaps == 0)[0]
        assert all(max_classes[zero_inds] == 0)
        # max overlap > 0 => class should not be zero (must be a fg class)
        nonzero_inds = np.where(max_overlaps > 0)[0]
        assert all(max_classes[nonzero_inds] != 0)
           

這個函數實際上隻是對roidb進一步設定,添加了5個鍵

image

width

height

max_classes

max_overlaps

。需要注意的是

image

鍵隻是存放了圖檔的路徑,并沒有實際讀取圖檔。

然後函數層層傳回到tools/train_faster_rcnn_alt_opt.py的get_roidb函數,然後又回到了train_rpn函數,我們繼續看這個函數,

output_dir = get_output_dir(imdb)

隻是設定了輸出路徑,不必深究,然後往下看train_net函數,它的輸入是roidb,難道是在這裡對roidb做了什麼處理?

我們進入這個函數,它的代碼在lib/fast_rcnn/train.py最後:

def train_net(solver_prototxt, roidb, output_dir,
              pretrained_model=None, max_iters=40000):
    """Train a Fast R-CNN network."""

    roidb = filter_roidb(roidb)
    sw = SolverWrapper(solver_prototxt, roidb, output_dir,
                       pretrained_model=pretrained_model)

    print 'Solving...'
    model_paths = sw.train_model(max_iters)
    print 'done solving'
    return model_paths
           

這裡調用了filter_roidb函數,代碼就在train_net函數的上面:

def filter_roidb(roidb):
    """Remove roidb entries that have no usable RoIs."""

    def is_valid(entry):
        # Valid images have:
        #   (1) At least one foreground RoI OR
        #   (2) At least one background RoI
        overlaps = entry['max_overlaps']
        # find boxes with sufficient overlap
        fg_inds = np.where(overlaps >= cfg.TRAIN.FG_THRESH)[0]
        # Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI)
        bg_inds = np.where((overlaps < cfg.TRAIN.BG_THRESH_HI) &
                           (overlaps >= cfg.TRAIN.BG_THRESH_LO))[0]
        # image is only valid if such boxes exist
        valid = len(fg_inds) > 0 or len(bg_inds) > 0
        return valid

    num = len(roidb)
    filtered_roidb = [entry for entry in roidb if is_valid(entry)]
    num_after = len(filtered_roidb)
    print 'Filtered {} roidb entries: {} -> {}'.format(num - num_after,
                                                       num, num_after)
    return filtered_roidb
           

看了下代碼,發現它隻是對roidb做了篩選,不禁讓人大失所望,搞到現在似乎隻是在操作xml檔案,那我們的圖像到底在哪裡讀的呢?我們回到train_net函數,看這行代碼:

這裡都訓練了你該讀取圖檔了吧,我們跟進去,這個函數也在train.py中:

def train_model(self, max_iters):
        """Network training loop."""
        last_snapshot_iter = -1
        timer = Timer()
        model_paths = []
        while self.solver.iter < max_iters:
            # Make one SGD update
            timer.tic()
            self.solver.step(1)
            timer.toc()
            if self.solver.iter % (10 * self.solver_param.display) == 0:
                print 'speed: {:.3f}s / iter'.format(timer.average_time)

            if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = self.solver.iter
                model_paths.append(self.snapshot())

        if last_snapshot_iter != self.solver.iter:
            model_paths.append(self.snapshot())
        return model_paths
           

看到

self.solver.step(1)

這一句,這裡就是訓練的一次疊代了,這個solver是caffe的SGDSolver,這裡涉及到caffe的代碼,我們就不跟下去了。

參考這篇文章faster-rcnn 之訓練資料是如何準備的:imdb和roidb的産生(caffe版本),可以知道圖檔資料是在神經網絡資料層前向傳播時讀取的,調用了lib/roi_data_layer/layer.py中RoIDataLayer類的forward方法,繼而調用_get_next_minibatch方法,然後又調用了lib/roi_data_layer/minibatch.py中的get_minibatch函數,最後調用了_get_image_blob函數,這裡的一行代碼

im = cv2.imread(roidb[i]['image'])

才真正開始讀取圖檔,然後再對圖檔做水準鏡像等其他處理。

到這裡,我們已經分析清了roidb和imdb的關系,而且可以看到roidb是訓練時主要用到的對象。文章開頭的bug也最終通過檢查,發現是自己資料的問題而解決了。雖然沒有和大家一起把圖像的其他處理看完,但還是收獲不小的。

繼續閱讀