
py-faster-rcnn Source Code Walkthrough: How the Training Data Is Prepared

While training py-faster-rcnn on my own data recently, I ran into the following error:

roidb[i]['image'] = imdb.image_path_at(i) 
IndexError: list index out of range 
           

Most answers online suggest deleting the .pkl files under py-faster-rcnn/data/cache, but that did not work for me, so I traced through the part of py-faster-rcnn that prepares the training data. I'm sharing the walkthrough here, partly as a record for myself.

All paths below are relative to the py-faster-rcnn directory; I won't repeat that again.

I train the model by running scripts/faster_rcnn_alt_opt.sh. From line 46 of that script:

time ./tools/train_faster_rcnn_alt_opt.py --gpu ${GPU_ID} \
           

we can see that the model is actually trained by tools/train_faster_rcnn_alt_opt.py, so that is the file we'll read next.

Training is split into two stages, and each stage starts by training the RPN, so let's go straight to the train_rpn function:

def train_rpn(queue=None, imdb_name=None, init_model=None, solver=None,
              max_iters=None, cfg=None):
    """Train a Region Proposal Network in a separate training process.
    """

    # Not using any proposals, just ground-truth boxes
    cfg.TRAIN.HAS_RPN = True
    cfg.TRAIN.BBOX_REG = False  # applies only to Fast R-CNN bbox regression
    cfg.TRAIN.PROPOSAL_METHOD = 'gt'
    cfg.TRAIN.IMS_PER_BATCH = 1
    print 'Init model: {}'.format(init_model)
    print('Using config:')
    pprint.pprint(cfg)

    import caffe
    _init_caffe(cfg)

    roidb, imdb = get_roidb(imdb_name)
    print 'roidb len: {}'.format(len(roidb))
    output_dir = get_output_dir(imdb)
    print 'Output will be saved to `{:s}`'.format(output_dir)

    model_paths = train_net(solver, roidb, output_dir,
                            pretrained_model=init_model,
                            max_iters=max_iters)
    # Cleanup all but the final model
    for i in model_paths[:-1]:
        os.remove(i)
    rpn_model_path = model_paths[-1]
    # Send final model path through the multiprocessing queue
    queue.put({'model_path': rpn_model_path})
           

The first few lines just set up the training configuration; the data preparation starts at these lines:

    roidb, imdb = get_roidb(imdb_name)
    print 'roidb len: {}'.format(len(roidb))
    output_dir = get_output_dir(imdb)
    print 'Output will be saved to `{:s}`'.format(output_dir)
           

So let's jump to get_roidb and see how it is implemented:

def get_roidb(imdb_name, rpn_file=None):
    imdb = get_imdb(imdb_name)
    print 'Loaded dataset `{:s}` for training'.format(imdb.name)
    imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
    print 'Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD)
    if rpn_file is not None:
        imdb.config['rpn_file'] = rpn_file
    roidb = get_training_roidb(imdb)
    return roidb, imdb
           

In "Stage 1 RPN, init from ImageNet model", the arguments are imdb_name = 'voc_2007_trainval' and rpn_file = None. What this function tells us is that roidb is built from imdb, so let's first see how imdb is obtained, i.e. the get_imdb function, whose code is in lib/datasets/factory.py:

def get_imdb(name):
    """Get an imdb (image database) by name."""
    if not __sets.has_key(name):
        raise KeyError('Unknown dataset: {}'.format(name))
    return __sets[name]()
           

At the top of that file, the line

__sets={}

defines __sets as a dictionary, and from get_imdb's return statement we can see that its keys are imdb names and its values are anonymous functions. Since we are using the voc_2007_trainval data, look a bit further up at this block:

# Set up voc_<year>_<split> using selective search "fast" mode
for year in ['2007', '2012']:
    for split in ['train', 'val', 'trainval', 'test']:
        name = 'voc_{}_{}'.format(year, split)
        __sets[name] = (lambda split=split, year=year: pascal_voc(split, year))
           

So the anonymous function mentioned above simply calls pascal_voc(split, year). Opening lib/datasets/pascal_voc.py, we find that this is the constructor of the pascal_voc class, which inherits from the imdb class.
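
A small aside on the lambda: the split=split, year=year default arguments are what freeze the loop variables at each iteration. A minimal illustration of why they are needed (not repo code):

makers_bad = {}
makers_good = {}
for year in ['2007', '2012']:
    makers_bad[year] = lambda: year                 # late binding: every call sees the final loop value
    makers_good[year] = lambda year=year: year      # default argument captures the value right now

print(makers_bad['2007']())   # prints 2012
print(makers_good['2007']())  # prints 2007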

class pascal_voc(imdb):
    def __init__(self, image_set, year, devkit_path=None):
        imdb.__init__(self, 'voc_' + year + '_' + image_set)
        self._year = year
        self._image_set = image_set
        self._devkit_path = self._get_default_path() if devkit_path is None \
                            else devkit_path
        self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)
        self._classes = ('__background__', # always index 0
                         'aeroplane', 'bicycle', 'bird', 'boat',
                         'bottle', 'bus', 'car', 'cat', 'chair',
                         'cow', 'diningtable', 'dog', 'horse',
                         'motorbike', 'person', 'pottedplant',
                         'sheep', 'sofa', 'train', 'tvmonitor')
        self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))
        self._image_ext = '.jpg'
        self._image_index = self._load_image_set_index()
        # Default to roidb handler
        self._roidb_handler = self.selective_search_roidb
        self._salt = str(uuid.uuid4())
        self._comp_id = 'comp4'

        # PASCAL specific config options
        self.config = {'cleanup'     : True,
                       'use_salt'    : True,
                       'use_diff'    : False,
                       'matlab_eval' : False,
                       'rpn_file'    : None,
                       'min_size'    : 2}

        assert os.path.exists(self._devkit_path), \
                'VOCdevkit path does not exist: {}'.format(self._devkit_path)
        assert os.path.exists(self._data_path), \
                'Path does not exist: {}'.format(self._data_path)
           

Most of this constructor is just configuration, but two lines deserve attention.

One is

self._image_index = self._load_image_set_index()

If you print it out, it contains the image names with the directory and extension stripped, so it will later be used to locate both the images and the XML files.
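
From memory of lib/datasets/pascal_voc.py, _load_image_set_index does roughly the following (details may differ slightly in your checkout): it reads one image id per line from the split's ImageSets file, and this is the list that image_path_at(i) later indexes into:

def _load_image_set_index(self):
    # e.g. VOCdevkit2007/VOC2007/ImageSets/Main/trainval.txt:
    # one image id per line, without directory or .jpg extension
    image_set_file = os.path.join(self._data_path, 'ImageSets', 'Main',
                                  self._image_set + '.txt')
    assert os.path.exists(image_set_file), \
            'Path does not exist: {}'.format(image_set_file)
    with open(image_set_file) as f:
        image_index = [x.strip() for x in f.readlines()]
    return image_index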

The other is

self._roidb_handler = self.selective_search_roidb

Jumping into selective_search_roidb and following the calls, it first calls gt_roidb, which in turn calls _load_pascal_annotation; the very first line of that method loads the XML file (the XML files store the regions of the objects to be detected).

That method ultimately returns a dictionary with the keys boxes, gt_classes, gt_overlaps, flipped and seg_areas; we will come back to them later.
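
For reference, here is a condensed sketch of what _load_pascal_annotation does for one image, abridged from memory of the repo (for example, the filtering of "difficult" objects is left out), so treat the exact details as approximate:

import os
import numpy as np
import scipy.sparse
import xml.etree.ElementTree as ET

def _load_pascal_annotation(self, index):
    # Parse Annotations/<index>.xml and collect every annotated object
    filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
    objs = ET.parse(filename).findall('object')

    num_objs = len(objs)
    boxes = np.zeros((num_objs, 4), dtype=np.uint16)
    gt_classes = np.zeros((num_objs), dtype=np.int32)
    overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
    seg_areas = np.zeros((num_objs), dtype=np.float32)

    for ix, obj in enumerate(objs):
        bbox = obj.find('bndbox')
        # VOC coordinates are 1-based, so shift them to 0-based pixel indexes
        x1 = float(bbox.find('xmin').text) - 1
        y1 = float(bbox.find('ymin').text) - 1
        x2 = float(bbox.find('xmax').text) - 1
        y2 = float(bbox.find('ymax').text) - 1
        cls = self._class_to_ind[obj.find('name').text.lower().strip()]
        boxes[ix, :] = [x1, y1, x2, y2]
        gt_classes[ix] = cls
        overlaps[ix, cls] = 1.0      # a ground-truth box overlaps its own class with 1.0
        seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)

    return {'boxes'       : boxes,
            'gt_classes'  : gt_classes,
            'gt_overlaps' : scipy.sparse.csr_matrix(overlaps),
            'flipped'     : False,
            'seg_areas'   : seg_areas}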

That's all for the pascal_voc class. Next, let's look at its parent class imdb, implemented in lib/datasets/imdb.py:

class imdb(object):
    """Image database."""

    def __init__(self, name):
        self._name = name
        self._num_classes = 0
        self._classes = []
        self._image_index = []
        self._obj_proposer = 'selective_search'
        self._roidb = None
        self._roidb_handler = self.default_roidb
        # Use this dict for storing dataset specific config options
        self.config = {}
           

The main takeaway here is that roidb is a member of imdb. What roidb actually contains is spelled out in the code below:

    @property
    def roidb(self):
        # A roidb is a list of dictionaries, each with the following keys:
        #   boxes
        #   gt_overlaps
        #   gt_classes
        #   flipped
        if self._roidb is not None:
            return self._roidb
        self._roidb = self.roidb_handler()
        return self._roidb
           

As the comment says, roidb is a list of dictionaries. Each one holds boxes (the object bounding boxes), gt_overlaps (the overlap information), gt_classes (the class of each box), and flipped (whether the entry was produced by flipping, really a horizontal mirror). In fact there is a fifth key, seg_areas, which we saw earlier; the comment presumably just omitted it by mistake.
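
The property above is just a lazy cache: the first access runs whatever handler is currently installed and stores the result, and later accesses return the cached list. A minimal standalone illustration of the same pattern (not repo code):

class LazyDB(object):
    def __init__(self):
        self._roidb = None
        # stand-in for gt_roidb / selective_search_roidb, etc.
        self.roidb_handler = lambda: ['expensive result']

    @property
    def roidb(self):
        # build on first access, return the cached list afterwards
        if self._roidb is not None:
            return self._roidb
        self._roidb = self.roidb_handler()
        return self._roidb

db = LazyDB()
print(db.roidb is db.roidb)  # True: the handler ran only once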

Now let's go back to the get_roidb function in tools/train_faster_rcnn_alt_opt.py:

def get_roidb(imdb_name, rpn_file=None):
    imdb = get_imdb(imdb_name)
    print 'Loaded dataset `{:s}` for training'.format(imdb.name)
    imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
    print 'Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD)
    if rpn_file is not None:
        imdb.config['rpn_file'] = rpn_file
    roidb = get_training_roidb(imdb)
    return roidb, imdb
           

From the analysis above we know that the imdb variable is actually an instance of the pascal_voc class, and pascal_voc is a subclass of the imdb class, so this variable has access to the methods and attributes of both classes, a detail worth keeping in mind.

Continuing right after get_imdb: when imdb.set_proposal_method runs it prints

Set proposal method: gt

in other words, it replaces the imdb's _roidb_handler member (here with the gt_roidb method); you can ignore that detail for now.
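
For the curious, set_proposal_method itself is tiny; from memory of lib/datasets/imdb.py it looks roughly like this, mapping the string 'gt' onto the gt_roidb method:

def set_proposal_method(self, method):
    # 'gt' -> self.gt_roidb, 'selective_search' -> self.selective_search_roidb, ...
    method = eval('self.' + method + '_roidb')
    self.roidb_handler = method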

Because rpn_file is None, the following if branch is simply skipped, so we go straight to the get_training_roidb function, whose code is in lib/fast_rcnn/train.py:

def get_training_roidb(imdb):
    """Returns a roidb (Region of Interest database) for use in training."""
    if cfg.TRAIN.USE_FLIPPED:
        print 'Appending horizontally-flipped training examples...'
        imdb.append_flipped_images()
        print 'done'

    print 'Preparing training data...'
    rdl_roidb.prepare_roidb(imdb)
    print 'done'

    return imdb.roidb
           

From the docstring we learn that roidb stands for Region of Interest database, while imdb is short for image database (the get_imdb docstring says so), so we can guess that roidb is what training actually consumes; we will confirm this later. Looking at the implementation, the if branch augments the data with horizontal mirrors by calling append_flipped_images, whose code is in lib/datasets/imdb.py:

    def append_flipped_images(self):
        num_images = self.num_images
        widths = self._get_widths()
        for i in xrange(num_images):
            boxes = self.roidb[i]['boxes'].copy()
            oldx1 = boxes[:, 0].copy()
            oldx2 = boxes[:, 2].copy()
            boxes[:, 0] = widths[i] - oldx2 - 1
            boxes[:, 2] = widths[i] - oldx1 - 1
            assert (boxes[:, 2] >= boxes[:, 0]).all()
            entry = {'boxes' : boxes,
                     'gt_overlaps' : self.roidb[i]['gt_overlaps'],
                     'gt_classes' : self.roidb[i]['gt_classes'],
                     'flipped' : True}
            self.roidb.append(entry)
        self._image_index = self._image_index * 2
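
As a quick numeric check of the coordinate arithmetic above (a made-up 500-pixel-wide image; coordinates are 0-based and inclusive, hence the -1):

import numpy as np

width = 500                              # hypothetical image width
boxes = np.array([[10, 30, 100, 80]])    # one box: x1, y1, x2, y2

oldx1 = boxes[:, 0].copy()
oldx2 = boxes[:, 2].copy()
boxes[:, 0] = width - oldx2 - 1          # new x1 = 500 - 100 - 1 = 399
boxes[:, 2] = width - oldx1 - 1          # new x2 = 500 - 10  - 1 = 489
print(boxes)                             # [[399  30 489  80]]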
           

As you can see, this function mirrors the box coordinates horizontally and marks the flipped copies with flipped = True. The point is mainly data augmentation: more training samples, less overfitting, better generalization. Note that no image has actually been touched yet; don't worry, we'll get there. Back in get_training_roidb, the next call is prepare_roidb, implemented in lib/roi_data_layer/roidb.py:

def prepare_roidb(imdb):
    """Enrich the imdb's roidb by adding some derived quantities that
    are useful for training. This function precomputes the maximum
    overlap, taken over ground-truth boxes, between each ROI and
    each ground-truth box. The class with maximum overlap is also
    recorded.
    """
    sizes = [PIL.Image.open(imdb.image_path_at(i)).size
             for i in xrange(imdb.num_images)]
    roidb = imdb.roidb
    for i in xrange(len(imdb.image_index)):
        roidb[i]['image'] = imdb.image_path_at(i)
        roidb[i]['width'] = sizes[i][0]
        roidb[i]['height'] = sizes[i][1]
        # need gt_overlaps as a dense array for argmax
        gt_overlaps = roidb[i]['gt_overlaps'].toarray()
        # max overlap with gt over classes (columns)
        max_overlaps = gt_overlaps.max(axis=1)
        # gt class that had the max overlap
        max_classes = gt_overlaps.argmax(axis=1)
        roidb[i]['max_classes'] = max_classes
        roidb[i]['max_overlaps'] = max_overlaps
        # sanity checks
        # max overlap of 0 => class should be zero (background)
        zero_inds = np.where(max_overlaps == 0)[0]
        assert all(max_classes[zero_inds] == 0)
        # max overlap > 0 => class should not be zero (must be a fg class)
        nonzero_inds = np.where(max_overlaps > 0)[0]
        assert all(max_classes[nonzero_inds] != 0)
           

This function only enriches roidb further, adding five more keys: image, width, height, max_classes and max_overlaps. Note that the image key merely stores the path to the image; the image itself is still not read.
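
Putting the two steps together, a single roidb entry now looks roughly like this (all values below are made up purely for illustration):

import numpy as np
import scipy.sparse

num_classes = 21  # 20 VOC classes + __background__

# gt_overlaps: one row per ground-truth box, a 1.0 in the column of its own class
overlaps = np.zeros((2, num_classes), dtype=np.float32)
overlaps[0, 13] = 1.0   # first box: 'horse'
overlaps[1, 15] = 1.0   # second box: 'person'

entry = {
    # produced by _load_pascal_annotation / gt_roidb
    'boxes'       : np.array([[ 48, 240, 195, 371],
                              [  8,  12, 352, 498]], dtype=np.uint16),
    'gt_classes'  : np.array([13, 15], dtype=np.int32),
    'gt_overlaps' : scipy.sparse.csr_matrix(overlaps),
    'flipped'     : False,
    'seg_areas'   : np.array([19536., 168015.], dtype=np.float32),
    # added later by prepare_roidb
    'image'        : '/path/to/VOCdevkit2007/VOC2007/JPEGImages/000005.jpg',
    'width'        : 500,
    'height'       : 375,
    'max_classes'  : np.array([13, 15]),
    'max_overlaps' : np.array([1.0, 1.0], dtype=np.float32),
}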

The calls then return, level by level, to get_roidb in tools/train_faster_rcnn_alt_opt.py and from there back to train_rpn. Continuing in that function,

output_dir = get_output_dir(imdb)

merely sets the output directory, nothing worth digging into. Next comes train_net, which takes roidb as input; is this where roidb finally gets processed further?

Let's step into it; its code is at the end of lib/fast_rcnn/train.py:

def train_net(solver_prototxt, roidb, output_dir,
              pretrained_model=None, max_iters=40000):
    """Train a Fast R-CNN network."""

    roidb = filter_roidb(roidb)
    sw = SolverWrapper(solver_prototxt, roidb, output_dir,
                       pretrained_model=pretrained_model)

    print 'Solving...'
    model_paths = sw.train_model(max_iters)
    print 'done solving'
    return model_paths
           

It first calls filter_roidb, whose code sits right above train_net:

def filter_roidb(roidb):
    """Remove roidb entries that have no usable RoIs."""

    def is_valid(entry):
        # Valid images have:
        #   (1) At least one foreground RoI OR
        #   (2) At least one background RoI
        overlaps = entry['max_overlaps']
        # find boxes with sufficient overlap
        fg_inds = np.where(overlaps >= cfg.TRAIN.FG_THRESH)[0]
        # Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI)
        bg_inds = np.where((overlaps < cfg.TRAIN.BG_THRESH_HI) &
                           (overlaps >= cfg.TRAIN.BG_THRESH_LO))[0]
        # image is only valid if such boxes exist
        valid = len(fg_inds) > 0 or len(bg_inds) > 0
        return valid

    num = len(roidb)
    filtered_roidb = [entry for entry in roidb if is_valid(entry)]
    num_after = len(filtered_roidb)
    print 'Filtered {} roidb entries: {} -> {}'.format(num - num_after,
                                                       num, num_after)
    return filtered_roidb
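
To make the filter concrete, here is a tiny self-contained check using what I believe are the default thresholds from lib/fast_rcnn/config.py (FG_THRESH = 0.5, BG_THRESH_HI = 0.5, BG_THRESH_LO = 0.1; double-check your own config), applied to a made-up entry. Note that at this RPN stage the roidb contains only ground-truth boxes whose max_overlaps are all 1.0, so little is filtered here; the filtering matters more for the Fast R-CNN stages:

import numpy as np

# Assumed default thresholds (verify against lib/fast_rcnn/config.py in your checkout)
FG_THRESH, BG_THRESH_HI, BG_THRESH_LO = 0.5, 0.5, 0.1

# Made-up entry: one well-overlapping box, one box that barely overlaps anything
entry = {'max_overlaps': np.array([0.8, 0.05])}

overlaps = entry['max_overlaps']
fg_inds = np.where(overlaps >= FG_THRESH)[0]                 # array([0])
bg_inds = np.where((overlaps < BG_THRESH_HI) &
                   (overlaps >= BG_THRESH_LO))[0]            # empty: 0.05 < BG_THRESH_LO
print(len(fg_inds) > 0 or len(bg_inds) > 0)                  # True -> the entry is kept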
           

Reading the code, it turns out it merely filters out roidb entries with no usable RoIs, which is a bit of a letdown: so far we seem to have been doing nothing but shuffling XML annotations around, so where are the images actually read? Back in train_net, look at this line:

model_paths = sw.train_model(max_iters)

Training is about to start here, so the images surely have to be read now. Let's follow train_model, which is also in train.py:

    def train_model(self, max_iters):
        """Network training loop."""
        last_snapshot_iter = -1
        timer = Timer()
        model_paths = []
        while self.solver.iter < max_iters:
            # Make one SGD update
            timer.tic()
            self.solver.step(1)
            timer.toc()
            if self.solver.iter % (10 * self.solver_param.display) == 0:
                print 'speed: {:.3f}s / iter'.format(timer.average_time)

            if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = self.solver.iter
                model_paths.append(self.snapshot())

        if last_snapshot_iter != self.solver.iter:
            model_paths.append(self.snapshot())
        return model_paths
           

The line

self.solver.step(1)

is one training iteration. The solver here is Caffe's SGDSolver; following it further would take us into the Caffe code itself, so we stop here.

According to the article "faster-rcnn 之训练数据是如何准备的: imdb和roidb的产生 (caffe版本)", the image data is read during the forward pass of the network's data layer: Caffe calls the forward method of the RoIDataLayer class in lib/roi_data_layer/layer.py, which calls _get_next_minibatch, which in turn calls get_minibatch in lib/roi_data_layer/minibatch.py, and finally _get_image_blob. There, the line

im = cv2.imread(roidb[i]['image'])

is where an image is finally read from disk; only then are horizontal mirroring and the other per-image processing applied.
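
For completeness, here is a condensed sketch of that helper, paraphrased from memory of lib/roi_data_layer/minibatch.py (cfg, prep_im_for_blob and im_list_to_blob come from fast_rcnn.config and utils.blob in the repo), so minor details may differ from your checkout:

def _get_image_blob(roidb, scale_inds):
    """Build an input blob from the images in the roidb at the chosen scales."""
    processed_ims = []
    im_scales = []
    for i in xrange(len(roidb)):
        im = cv2.imread(roidb[i]['image'])   # the image is finally read from disk here
        if roidb[i]['flipped']:
            im = im[:, ::-1, :]              # apply the horizontal mirror lazily
        target_size = cfg.TRAIN.SCALES[scale_inds[i]]
        im, im_scale = prep_im_for_blob(im, cfg.PIXEL_MEANS, target_size,
                                        cfg.TRAIN.MAX_SIZE)
        im_scales.append(im_scale)
        processed_ims.append(im)
    # Stack the processed images into a single network input blob
    blob = im_list_to_blob(processed_ims)
    return blob, im_scales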

At this point the relationship between roidb and imdb should be clear, and we can see that roidb is the object training actually relies on. As for the bug at the beginning of this post, it eventually turned out to be a problem with my own data and was fixed after checking it. We haven't walked through the rest of the image processing together, but this was still a worthwhile trip through the code.
