天天看點

py-faster-rcnn詳解(3)——train.py接口說明 snapshot train_model get_training_roidb filter_roidb train_net

主要内容是一個solver包裝類(SolverWrapper),主要目的是為了實作自己的snapshot。

class SolverWrapper(object):
    """A simple wrapper around Caffe's solver.

    This wrapper gives us control over the snapshotting process, which we
    use to unnormalize the learned bounding-box regression weights.
    """

    def __init__(self, solver_prototxt, roidb, output_dir,
                 pretrained_model=None):
        """Initialize the SolverWrapper.

        Args:
            solver_prototxt: path of the solver definition; it is handed to
                caffe.SGDSolver and also parsed into self.solver_param.
            roidb: RoI database; used to compute bbox regression targets
                (when cfg.TRAIN.BBOX_REG is set) and attached to the net's
                first layer through set_roidb.
            output_dir: stored as self.output_dir.
            pretrained_model: optional weights file copied into the net
                when it is not None.

        Raises:
            AssertionError: if RPN training with normalized bbox targets is
                requested without precomputed normalization statistics.
        """
        self.output_dir = output_dir

        if (cfg.TRAIN.HAS_RPN and cfg.TRAIN.BBOX_REG and
                cfg.TRAIN.BBOX_NORMALIZE_TARGETS):
            # RPN can only use precomputed normalization because there are no
            # fixed statistics to compute a priori
            assert cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED

        if cfg.TRAIN.BBOX_REG:
            print('Computing bounding-box regression targets...')
            # The returned statistics are kept on the instance so they are
            # available after construction.
            self.bbox_means, self.bbox_stds = \
                rdl_roidb.add_bbox_regression_targets(roidb)
            print('done')

        self.solver = caffe.SGDSolver(solver_prototxt)
        if pretrained_model is not None:
            print(('Loading pretrained model '
                   'weights from {:s}').format(pretrained_model))
            self.solver.net.copy_from(pretrained_model)

        # Parse the solver prototxt a second time into a protobuf message so
        # its fields can be read from Python.
        self.solver_param = caffe_pb2.SolverParameter()
        with open(solver_prototxt, 'rt') as f:
            pb2.text_format.Merge(f.read(), self.solver_param)

        # Attach the roidb to the net: the first layer (index 0) is expected
        # to expose set_roidb.
        self.solver.net.layers[0].set_roidb(roidb)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37

snapshot

實作自己的snapshot。

def snapshot(self):
    """Take a snapshot of the network after unnormalizing the learned
    bounding-box regression weights. This enables easy use at test-time.

    The 'bbox_pred' parameters are rescaled in place, the net is saved, and
    the original parameter values are then written back so the in-memory
    net is left unchanged.

    Returns:
        The path of the written .caffemodel file (inside self.output_dir).
    """
    net = self.solver.net

    scale_bbox_params = (cfg.TRAIN.BBOX_REG and
                         cfg.TRAIN.BBOX_NORMALIZE_TARGETS and
                         'bbox_pred' in net.params)

    if scale_bbox_params:
        # save original values
        # (blob 0 / blob 1 -- presumably weights / biases, per the usual
        # Caffe parameter layout)
        orig_0 = net.params['bbox_pred'][0].data.copy()
        orig_1 = net.params['bbox_pred'][1].data.copy()

        # scale and shift with bbox reg unnormalization; then save snapshot
        net.params['bbox_pred'][0].data[...] = \
            (net.params['bbox_pred'][0].data *
             self.bbox_stds[:, np.newaxis])
        net.params['bbox_pred'][1].data[...] = \
            (net.params['bbox_pred'][1].data *
             self.bbox_stds + self.bbox_means)

    # An empty SNAPSHOT_INFIX contributes nothing to the file name.
    infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
             if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
    filename = (self.solver_param.snapshot_prefix + infix +
                '_iter_{:d}'.format(self.solver.iter) + '.caffemodel')
    filename = os.path.join(self.output_dir, filename)

    try:
        net.save(str(filename))
    finally:
        # Restore even when saving fails, so training never continues on
        # the unnormalized parameters.
        if scale_bbox_params:
            # restore net to original state
            net.params['bbox_pred'][0].data[...] = orig_0
            net.params['bbox_pred'][1].data[...] = orig_1

    print('Wrote snapshot to: {:s}'.format(filename))
    return filename
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37

train_model

訓練主流程,并控制了snapshot的過程。

def train_model(self, max_iters):
    """Network training loop.

    Steps the solver one iteration at a time until self.solver.iter reaches
    max_iters, taking a snapshot every cfg.TRAIN.SNAPSHOT_ITERS iterations
    and one more at the end if the last iteration was not snapshotted.

    Args:
        max_iters: iteration count at which the loop stops.

    Returns:
        List of the snapshot file paths returned by self.snapshot(), in the
        order they were written.
    """
    # -1 never equals a real iteration count, so a final snapshot is always
    # taken when no periodic one happened.
    last_snapshot_iter = -1
    timer = Timer()
    model_paths = []
    while self.solver.iter < max_iters:
        # Make one SGD update
        timer.tic()
        self.solver.step(1)
        timer.toc()
        # Report timing once every 10 display intervals.
        if self.solver.iter % (10 * self.solver_param.display) == 0:
            print('speed: {:.3f}s / iter'.format(timer.average_time))

        if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
            last_snapshot_iter = self.solver.iter
            model_paths.append(self.snapshot())

    # Make sure the final state is saved exactly once.
    if last_snapshot_iter != self.solver.iter:
        model_paths.append(self.snapshot())
    return model_paths
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

get_training_roidb

 将roidb中的每張圖檔水準翻轉,并添加回去,以及調用prepare_roidb做了些準備性的工作。

def get_training_roidb(imdb):
    """Returns a roidb (Region of Interest database) for use in training.

    Args:
        imdb: image database object; it is modified in place
            (append_flipped_images when cfg.TRAIN.USE_FLIPPED is set, then
            rdl_roidb.prepare_roidb).

    Returns:
        imdb.roidb after the preparation steps above.
    """
    if cfg.TRAIN.USE_FLIPPED:
        print('Appending horizontally-flipped training examples...')
        imdb.append_flipped_images()
        print('done')

    print('Preparing training data...')
    rdl_roidb.prepare_roidb(imdb)
    print('done')

    return imdb.roidb
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

filter_roidb

該函數中定義了一個is_valid函數,用于判斷roidb中的每個entry是否至少有一個前景box或背景box。 

roidb全是ground truth時,因為box與對應的類的重合度(overlaps)顯然為1,也就是說roidb起碼要有一個标記類。 

如果roidb包含了一些proposal,overlaps在[BG_THRESH_LO, BG_THRESH_HI)之間的都将被認為是背景,大于等于FG_THRESH才被認為是前景,roidb 至少要有一個前景或背景,否則将被過濾掉。 

def filter_roidb(roidb):
    """Remove roidb entries that have no usable RoIs.

    Args:
        roidb: iterable of entries, each indexable by 'max_overlaps'.

    Returns:
        A new list containing only the entries that have at least one
        foreground or one background RoI; the input is not modified.
    """

    def is_valid(entry):
        """Return True when the entry has a foreground or background RoI."""
        # Valid images have:
        #   (1) At least one foreground RoI OR
        #   (2) At least one background RoI
        overlaps = entry['max_overlaps']
        # find boxes with sufficient overlap
        # (np.where with one argument returns a tuple; [0] is the index array)
        fg_inds = np.where(overlaps >= cfg.TRAIN.FG_THRESH)[0]
        # Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI)
        bg_inds = np.where((overlaps < cfg.TRAIN.BG_THRESH_HI) &
                           (overlaps >= cfg.TRAIN.BG_THRESH_LO))[0]
        # image is only valid if such boxes exist
        valid = len(fg_inds) > 0 or len(bg_inds) > 0
        return valid

    num = len(roidb)
    filtered_roidb = [entry for entry in roidb if is_valid(entry)]
    num_after = len(filtered_roidb)
    print('Filtered {} roidb entries: {} -> {}'.format(num - num_after,
                                                       num, num_after))
    return filtered_roidb
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

train_net

進行網絡的訓練。

def train_net(solver_prototxt, roidb, output_dir,
              pretrained_model=None, max_iters=40000):
    """Train a Fast R-CNN network.

    Args:
        solver_prototxt: solver definition path, forwarded to SolverWrapper.
        roidb: RoI database; entries without usable RoIs are dropped by
            filter_roidb before training.
        output_dir: forwarded to SolverWrapper.
        pretrained_model: optional weights file, forwarded to SolverWrapper.
        max_iters: iteration limit passed to SolverWrapper.train_model.

    Returns:
        The list of snapshot paths returned by train_model.
    """

    roidb = filter_roidb(roidb)
    sw = SolverWrapper(solver_prototxt, roidb, output_dir,
                       pretrained_model=pretrained_model)

    print('Solving...')
    model_paths = sw.train_model(max_iters)
    print('done solving')
    return model_paths
           

繼續閱讀