
EAST OCR Text Detection with Source Code (EAST: An Efficient and Accurate Scene Text Detector)

EAST: An Efficient and Accurate Scene Text Detector

Previous methods for scene text detection have achieved good performance on various benchmarks. However, they usually fall short when dealing with challenging scenarios, even when equipped with deep neural network models, because the overall performance depends on the interplay of multiple stages and components in the pipeline. EAST proposes a simple yet powerful pipeline that yields fast and accurate text detection in natural scenes. The algorithm directly predicts words or text lines of arbitrary orientations and quadrilateral shapes in full images, eliminating unnecessary intermediate steps (e.g., candidate aggregation and word partitioning) with a single neural network. Experiments on standard datasets including ICDAR 2015, COCO-Text and MSRA-TD500 demonstrate that the proposed algorithm significantly outperforms state-of-the-art methods in terms of both accuracy and efficiency. On the ICDAR 2015 dataset, it achieves an F-score of 0.7820 at 13.2 fps at 720p resolution.

Specifically, the algorithm scores an F-measure of 0.7820 on ICDAR 2015 [15] (0.8072 when tested at multiple scales), 0.7608 on MSRA-TD500 [40] and 0.3945 on COCO-Text [36], outperforming previous state-of-the-art algorithms while spending much less time on average (13.2 fps at 720p resolution for the best model and 16.8 fps for the fastest model, on a Titan-X GPU).

Main contributions:

  1. A scene text detection method consisting of two stages: a fully convolutional network (FCN) and an NMS merging stage. The FCN directly produces text regions, excluding redundant and time-consuming intermediate steps.
  2. The algorithm can flexibly produce word-level or line-level predictions, whose geometry can be a rotated box or a quadrangle depending on the specific application.
  3. The proposed algorithm significantly outperforms state-of-the-art methods in both accuracy and speed.

Model architecture:

[Figure: EAST model architecture]

One of the predicted channels is the score map, whose pixel values are in the range [0, 1]. The remaining channels represent the geometry that encloses the word from the view of each pixel. The score stands for the confidence of the geometry shape predicted at the same location.

Two geometry shapes are experimented for text regions, the rotated box (RBOX) and the quadrangle (QUAD), with a different loss function designed for each. Thresholding is then applied to each predicted region: geometries whose scores exceed a predefined threshold are considered valid and kept for later non-maximum suppression. The results after NMS are the final output of the pipeline.
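
As a rough sketch of that thresholding step (illustrative only; the repo's full post-processing is the detect() method shown later in this article), assuming score_map is an (H, W) array and geo_map an (H, W, 5) array:

import numpy as np

def keep_confident_geometries(score_map, geo_map, score_thresh=0.8):
    # keep only pixel locations whose text confidence exceeds the threshold
    ys, xs = np.where(score_map > score_thresh)
    # each kept location contributes its score and its predicted geometry;
    # these candidates are later decoded into quadrangles and merged by NMS
    return score_map[ys, xs], geo_map[ys, xs]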

def resnet_east(backbone='resnet50', inputs=None, modifier=None, **kwargs):
    # choose default input
    if inputs is None:
        if keras.backend.image_data_format() == 'channels_first':
            inputs = keras.layers.Input(shape=(3, None, None))
        else:
            inputs = keras.layers.Input(shape=(None, None, 3))

    # create the resnet backbone
    if backbone == 'resnet50':
        resnet = keras.applications.ResNet50(input_tensor=inputs, include_top=False, weights=None)
    elif backbone == 'resnet101':
        resnet = keras.applications.ResNet101(input_tensor=inputs, include_top=False, weights=None)
    elif backbone == 'resnet152':
        resnet = keras.applications.ResNet152(input_tensor=inputs, include_top=False, weights=None)
    else:
        raise ValueError("Backbone '{}' not recognized.".format(backbone))

    if modifier:
        resnet = modifier(resnet)

    layer_names = ['activation_49', 'activation_40', 'activation_22', 'activation_10']
    backbone_layers = [resnet.get_layer(i).output for i in layer_names]

    return east(input=inputs, backbone_layers=backbone_layers, **kwargs)

def east(
    input,
    backbone_layers,
    config = None,
    name = 'east'
):
    overly_small_text_region_training_mask = Input(shape=(None, None, 1), name='overly_small_text_region_training_mask')
    text_region_boundary_training_mask = Input(shape=(None, None, 1), name='text_region_boundary_training_mask')
    target_score_map = Input(shape=(None, None, 1), name='target_score_map')

    act_49, act_40, act_22, act_10 = backbone_layers

    if config is None:
        config = cfg

    x = Lambda(resize_bilinear, name='resize_1')(act_49)
    x = concatenate([x, act_40], axis=3)
    x = Conv2D(128, (1, 1), padding='same', kernel_regularizer=l2(1e-5))(x)
    x = BatchNormalization(momentum=0.997, epsilon=1e-5, scale=True)(x)
    x = Activation('relu')(x)
    x = Conv2D(128, (3, 3), padding='same', kernel_regularizer=l2(1e-5))(x)
    x = BatchNormalization(momentum=0.997, epsilon=1e-5, scale=True)(x)
    x = Activation('relu')(x)

    x = Lambda(resize_bilinear, name='resize_2')(x)
    x = concatenate([x, act_22], axis=3)
    x = Conv2D(64, (1, 1), padding='same', kernel_regularizer=l2(1e-5))(x)
    x = BatchNormalization(momentum=0.997, epsilon=1e-5, scale=True)(x)
    x = Activation('relu')(x)
    x = Conv2D(64, (3, 3), padding='same', kernel_regularizer=l2(1e-5))(x)
    x = BatchNormalization(momentum=0.997, epsilon=1e-5, scale=True)(x)
    x = Activation('relu')(x)

    x = Lambda(resize_bilinear, name='resize_3')(x)
    x = concatenate([x, act_10], axis=3)
    x = Conv2D(32, (1, 1), padding='same', kernel_regularizer=l2(1e-5))(x)
    x = BatchNormalization(momentum=0.997, epsilon=1e-5, scale=True)(x)
    x = Activation('relu')(x)
    x = Conv2D(32, (3, 3), padding='same', kernel_regularizer=l2(1e-5))(x)
    x = BatchNormalization(momentum=0.997, epsilon=1e-5, scale=True)(x)
    x = Activation('relu')(x)

    x = Conv2D(32, (3, 3), padding='same', kernel_regularizer=l2(1e-5))(x)
    x = BatchNormalization(momentum=0.997, epsilon=1e-5, scale=True)(x)
    x = Activation('relu')(x)

    pred_score_map = Conv2D(1, (1, 1), activation=keras.backend.sigmoid, name='pred_score_map')(x)
    rbox_geo_map = Conv2D(4, (1, 1), activation=keras.backend.sigmoid, name='rbox_geo_map')(x)
    rbox_geo_map = Lambda(lambda x: x * config.INPUT_SIZE)(rbox_geo_map)
    angle_map = Conv2D(1, (1, 1), activation=keras.backend.sigmoid, name='rbox_angle_map')(x)
    angle_map = Lambda(lambda x: (x - 0.5) * np.pi / 2)(angle_map)
    pred_geo_map = concatenate([rbox_geo_map, angle_map], axis=3, name='pred_geo_map')

    model = keras.models.Model(inputs=[input, overly_small_text_region_training_mask,
                                       text_region_boundary_training_mask,
                                       target_score_map],
                               outputs=[pred_score_map, pred_geo_map],
                               name=name)
    return model
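
A minimal usage sketch of the builder defined above (cfg and the repo's helpers are assumed to be importable; the three mask inputs are only consumed by the training losses):

model = resnet_east(backbone='resnet50')
model.summary()
# outputs: pred_score_map, 1 channel in [0, 1] (text confidence), and pred_geo_map,
# 5 channels: 4 edge distances scaled to pixels by config.INPUT_SIZE plus 1 angle in [-pi/4, pi/4]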
           

The geometry output can be either RBOX or QUAD, as summarized in Tab. 1.

[Table 1: output geometry design. RBOX: 4-channel AABB plus 1-channel rotation angle; QUAD: 8-channel coordinate offsets of the four corner vertices.]

For RBOX, the geometry is represented by a 4-channel axis-aligned bounding box (AABB) R and a 1-channel rotation angle θ. The formulation of R is the same as in [9], where the 4 channels represent the 4 distances from the pixel location to the top, right, bottom and left boundaries of the rectangle, respectively.

For QUAD Q, 8 numbers are used to denote the coordinate offsets from the four corner vertices {p_i | i ∈ {1, 2, 3, 4}} of the quadrangle to the pixel location. Since each distance offset contains two numbers (Δx_i, Δy_i), the geometry output contains 8 channels.

The code above implements the RBOX case.
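
For intuition, consider a hypothetical axis-aligned ground-truth box spanning x ∈ [10, 110], y ∈ [20, 50] with θ = 0: a pixel at (60, 35) inside it gets the RBOX targets d_top = 15, d_right = 50, d_bottom = 15, d_left = 50 and angle 0, while its score-map value is 1 only if it also lies inside the shrunk quadrangle described below.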

Pre-processing (label generation):

Only the case where the geometry is a quadrangle is considered here. The positive area of the quadrangle on the score map is designed to be roughly a shrunk version of the original one, as illustrated in the figure.

[Figure: label generation process]

For a quadrangle Q = {p_i | i ∈ {1, 2, 3, 4}},

where p_i = {x_i, y_i} are the vertices of the quadrangle in clockwise order. To shrink Q, we first compute a reference length r_i for each vertex p_i:

r_i = min(D(p_i, p_(i mod 4)+1), D(p_i, p_((i+2) mod 4)+1))

where D(p_i, p_j) is the L2 distance between p_i and p_j.

We first shrink the two longer edges of the quadrangle, and then the two shorter ones. For each pair of opposing edges, the "longer" pair is determined by comparing the mean of their lengths.

For each edge (p_i, p_(i mod 4)+1), it is shrunk by moving its two endpoints inward along the edge by 0.3 r_i and 0.3 r_((i mod 4)+1), respectively.
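
As a quick hypothetical check: for an axis-aligned 100 x 20 box, every vertex has r_i = 20 (the shorter adjacent edge), so the two long edges are shrunk by 0.3 x 20 = 6 pixels at each end first, and then the two short edges by the same amount, leaving roughly an 88 x 8 positive region on the score map.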

def compute_targets_east(self, imggroup, anngroup):
        score_maps = []
        geo_maps = []
        overly_small_text_region_training_masks = []
        text_region_boundary_training_masks = []
        for img, ann in zip(imggroup, anngroup):
            gtbox = ann['bboxes'].reshape(((-1, 4, 2)))
            label = ann['labels']
            h, w, c = img.shape

            score_map, geo_map, overly_small_text_region_training_mask, text_region_boundary_training_mask = generate_rbox(
                (h, w), gtbox, label, self.config)
            score_maps.append(score_map[::4, ::4, np.newaxis].astype(np.float32))
            geo_maps.append(geo_map[::4, ::4, :].astype(np.float32))
            overly_small_text_region_training_masks.append(
                overly_small_text_region_training_mask[::4, ::4, np.newaxis].astype(np.float32))
            text_region_boundary_training_masks.append(
                text_region_boundary_training_mask[::4, ::4, np.newaxis].astype(np.float32))

        return [np.array(overly_small_text_region_training_masks),
                np.array(text_region_boundary_training_masks),
                np.array(score_maps)], \
               [np.array(score_maps), np.array(geo_maps)]
def generate_rbox(im_size, polys, tags, config = None):
    if config is None:
        config = cfg
    h, w = im_size
    shrinked_poly_mask = np.zeros((h, w), dtype=np.uint8)
    orig_poly_mask = np.zeros((h, w), dtype=np.uint8)
    score_map = np.zeros((h, w), dtype=np.uint8)
    geo_map = np.zeros((h, w, 5), dtype=np.float32)
    # mask used during training, to ignore some hard areas
    overly_small_text_region_training_mask = np.ones((h, w), dtype=np.uint8)
    for poly_idx, poly_data in enumerate(zip(polys, tags)):
        poly = poly_data[0]
        tag = poly_data[1]
        # r_i: the shorter of the two edges adjacent to vertex i (reference length for shrinking)
        r = [None, None, None, None]
        for i in range(4):
            r[i] = min(np.linalg.norm(poly[i] - poly[(i + 1) % 4]),
                       np.linalg.norm(poly[i] - poly[(i - 1) % 4]))
        # shrink the quadrangle
        shrinked_poly = shrink_poly(poly.copy(), r).astype(np.int32)[np.newaxis, :, :]
        # fill the score map with the shrunk quadrangle
        cv2.fillPoly(score_map, shrinked_poly, 1)
        cv2.fillPoly(shrinked_poly_mask, shrinked_poly, poly_idx + 1)
        # fill the original quadrangle mask
        cv2.fillPoly(orig_poly_mask, poly.astype(np.int32)[np.newaxis, :, :], 1)
        # ignore quadrangles whose shorter side is too small
        poly_h = min(np.linalg.norm(poly[0] - poly[3]), np.linalg.norm(poly[1] - poly[2]))
        poly_w = min(np.linalg.norm(poly[0] - poly[1]), np.linalg.norm(poly[2] - poly[3]))
        if min(poly_h, poly_w) < config.min_text_size:
            cv2.fillPoly(overly_small_text_region_training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0)
        if tag:
            cv2.fillPoly(overly_small_text_region_training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0)

        xy_in_poly = np.argwhere(shrinked_poly_mask == (poly_idx + 1))
        # if geometry == 'RBOX':
        # generate a parallelogram for any combination of two vertices
        fitted_parallelograms = []
        for i in range(4):
            p0 = poly[i]
            p1 = poly[(i + 1) % 4]
            p2 = poly[(i + 2) % 4]
            p3 = poly[(i + 3) % 4]
            edge = fit_line([p0[0], p1[0]], [p0[1], p1[1]])  # line a*x + b*y + c = 0 through p0 and p1, as [a, b, c]
            backward_edge = fit_line([p0[0], p3[0]], [p0[1], p3[1]])
            forward_edge = fit_line([p1[0], p2[0]], [p1[1], p2[1]])
            # is p2 farther from line p0-p1 than p3?
            if point_dist_to_line(p0, p1, p2) > point_dist_to_line(p0, p1, p3):
                # line through p2 parallel to p0-p1
                if edge[1] == 0:
                    edge_opposite = [1, 0, -p2[0]]
                else:
                    edge_opposite = [edge[0], -1, p2[1] - edge[0] * p2[0]]
            else:
                # line through p3 parallel to p0-p1
                if edge[1] == 0:
                    edge_opposite = [1, 0, -p3[0]]
                else:
                    edge_opposite = [edge[0], -1, p3[1] - edge[0] * p3[0]]
            # move forward edge
            new_p0 = p0
            new_p1 = p1
            new_p2 = p2
            new_p3 = p3
            new_p2 = line_cross_point(forward_edge, edge_opposite)  # intersection of the p1-p2 line with edge_opposite
            if point_dist_to_line(p1, new_p2, p0) > point_dist_to_line(p1, new_p2, p3):
                # across p0
                if forward_edge[1] == 0:
                    forward_opposite = [1, 0, -p0[0]]
                else:
                    forward_opposite = [forward_edge[0], -1, p0[1] - forward_edge[0] * p0[0]]
            else:
                # across p3
                if forward_edge[1] == 0:
                    forward_opposite = [1, 0, -p3[0]]
                else:
                    forward_opposite = [forward_edge[0], -1, p3[1] - forward_edge[0] * p3[0]]
            new_p0 = line_cross_point(forward_opposite, edge)
            new_p3 = line_cross_point(forward_opposite, edge_opposite)
            fitted_parallelograms.append([new_p0, new_p1, new_p2, new_p3, new_p0])
            # or move backward edge
            new_p0 = p0
            new_p1 = p1
            new_p2 = p2
            new_p3 = p3
            new_p3 = line_cross_point(backward_edge, edge_opposite)  # intersection of the p0-p3 line with edge_opposite
            if point_dist_to_line(p0, p3, p1) > point_dist_to_line(p0, p3, p2):
                # across p1
                if backward_edge[1] == 0:
                    backward_opposite = [1, 0, -p1[0]]
                else:
                    backward_opposite = [backward_edge[0], -1, p1[1] - backward_edge[0] * p1[0]]
            else:
                # across p2
                if backward_edge[1] == 0:
                    backward_opposite = [1, 0, -p2[0]]
                else:
                    backward_opposite = [backward_edge[0], -1, p2[1] - backward_edge[0] * p2[0]]
            new_p1 = line_cross_point(backward_opposite, edge)
            new_p2 = line_cross_point(backward_opposite, edge_opposite)
            fitted_parallelograms.append([new_p0, new_p1, new_p2, new_p3, new_p0])
        areas = [Polygon(t).area for t in fitted_parallelograms]  # areas of all fitted parallelograms
        parallelogram = np.array(fitted_parallelograms[np.argmin(areas)][:-1], dtype=np.float32)  # keep the one with the smallest area
        # sort the polygon vertices
        parallelogram_coord_sum = np.sum(parallelogram, axis=1)  # x + y for each vertex
        min_coord_idx = np.argmin(parallelogram_coord_sum)  # vertex with the smallest x + y
        parallelogram = parallelogram[
            [min_coord_idx, (min_coord_idx + 1) % 4, (min_coord_idx + 2) % 4, (min_coord_idx + 3) % 4]]

        rectange = rectangle_from_parallelogram(parallelogram)  # rotated rectangle fitted to the parallelogram
        rectange, rotate_angle = sort_rectangle(rectange)  # sort rectangle vertices clockwise and get the rotation angle

        p0_rect, p1_rect, p2_rect, p3_rect = rectange
        for y, x in xy_in_poly:
            point = np.array([x, y], dtype=np.float32)
            # top
            geo_map[y, x, 0] = point_dist_to_line(p0_rect, p1_rect, point)
            # right
            geo_map[y, x, 1] = point_dist_to_line(p1_rect, p2_rect, point)
            # down
            geo_map[y, x, 2] = point_dist_to_line(p2_rect, p3_rect, point)
            # left
            geo_map[y, x, 3] = point_dist_to_line(p3_rect, p0_rect, point)
            # angle
            geo_map[y, x, 4] = rotate_angle

    shrinked_poly_mask = (shrinked_poly_mask > 0).astype('uint8')
    text_region_boundary_training_mask = 1 - (orig_poly_mask - shrinked_poly_mask)

    return score_map, geo_map, overly_small_text_region_training_mask, text_region_boundary_training_mask
def shrink_poly(poly, r):
    '''
    fit a poly inside the original poly, maybe bugs here...
    used for generating the score map
    :param poly: the text poly
    :param r: r in the paper
    :return: the shrinked poly
    '''
    # shrink ratio
    R = 0.3
    # find the longer pair
    if np.linalg.norm(poly[0] - poly[1]) + np.linalg.norm(poly[2] - poly[3]) > \
                    np.linalg.norm(poly[0] - poly[3]) + np.linalg.norm(poly[1] - poly[2]):
        # first move (p0, p1), (p2, p3), then (p0, p3), (p1, p2)
        ## p0, p1
        theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0]))
        poly[0][0] += R * r[0] * np.cos(theta)
        poly[0][1] += R * r[0] * np.sin(theta)
        poly[1][0] -= R * r[1] * np.cos(theta)
        poly[1][1] -= R * r[1] * np.sin(theta)
        ## p2, p3
        theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0]))
        poly[3][0] += R * r[3] * np.cos(theta)
        poly[3][1] += R * r[3] * np.sin(theta)
        poly[2][0] -= R * r[2] * np.cos(theta)
        poly[2][1] -= R * r[2] * np.sin(theta)
        ## p0, p3
        theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1]))
        poly[0][0] += R * r[0] * np.sin(theta)
        poly[0][1] += R * r[0] * np.cos(theta)
        poly[3][0] -= R * r[3] * np.sin(theta)
        poly[3][1] -= R * r[3] * np.cos(theta)
        ## p1, p2
        theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1]))
        poly[1][0] += R * r[1] * np.sin(theta)
        poly[1][1] += R * r[1] * np.cos(theta)
        poly[2][0] -= R * r[2] * np.sin(theta)
        poly[2][1] -= R * r[2] * np.cos(theta)
    else:
        ## p0, p3
        theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1]))
        poly[0][0] += R * r[0] * np.sin(theta)
        poly[0][1] += R * r[0] * np.cos(theta)
        poly[3][0] -= R * r[3] * np.sin(theta)
        poly[3][1] -= R * r[3] * np.cos(theta)
        ## p1, p2
        theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1]))
        poly[1][0] += R * r[1] * np.sin(theta)
        poly[1][1] += R * r[1] * np.cos(theta)
        poly[2][0] -= R * r[2] * np.sin(theta)
        poly[2][1] -= R * r[2] * np.cos(theta)
        ## p0, p1
        theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0]))
        poly[0][0] += R * r[0] * np.cos(theta)
        poly[0][1] += R * r[0] * np.sin(theta)
        poly[1][0] -= R * r[1] * np.cos(theta)
        poly[1][1] -= R * r[1] * np.sin(theta)
        ## p2, p3
        theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0]))
        poly[3][0] += R * r[3] * np.cos(theta)
        poly[3][1] += R * r[3] * np.sin(theta)
        poly[2][0] -= R * r[2] * np.cos(theta)
        poly[2][1] -= R * r[2] * np.sin(theta)
    return poly
def fit_line(p1, p2):
    # fit a line a*x + b*y + c = 0 through two points; p1 = [x0, x1] and p2 = [y0, y1]
    if p1[0] == p1[1]:
        # vertical line: x = x0
        return [1., 0., -p1[0]]
    else:
        [k, b] = np.polyfit(p1, p2, deg=1)
        return [k, -1., b]
           

Loss function:

L = L_s + λ_g L_g

where L_s and L_g are the losses for the score map and the geometry, and λ_g balances the classification loss and the regression loss; it is set to 1.

In most state-of-the-art detection algorithms, training images are carefully processed with balanced sampling and hard negative mining to tackle the imbalanced distribution of target objects [9, 28]. Doing so would likely improve network performance. However, such techniques inevitably introduce a non-differentiable stage, more parameters to tune and a more complicated pipeline, which contradicts the design principle here.

To facilitate a simpler training procedure, the class-balanced cross-entropy introduced in [38] is used, given by:

L_s = balanced-xent(Y^, Y*) = -β Y* log(Y^) - (1 - β)(1 - Y*) log(1 - Y^)

where the parameter β is the balancing factor between positive and negative samples, given by

β = 1 - |{y* ∈ Y* : y* = 1}| / |Y*|

Here Y* is the ground-truth score map and Y^ is the prediction.

This balanced cross-entropy loss was first adopted by Yao et al. [41] as the objective function for score map prediction in text detection, and it works well in practice.
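
A minimal Keras-style sketch of that class-balanced cross-entropy, assuming y_true/y_pred are the score maps described above (this helper is illustrative and not part of the repo, which uses the dice loss below instead):

import keras.backend as K

def balanced_crossentropy(y_true, y_pred, eps=1e-6):
    # beta = 1 - (#positive pixels / #pixels), estimated per batch
    beta = 1. - K.mean(y_true)
    y_pred = K.clip(y_pred, eps, 1. - eps)
    # -beta * Y* * log(Y^) - (1 - beta) * (1 - Y*) * log(1 - Y^)
    loss = -beta * y_true * K.log(y_pred) - (1. - beta) * (1. - y_true) * K.log(1. - y_pred)
    return K.mean(loss)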

The following code instead uses a dice loss on the sigmoid-activated score map; in practice it converges faster than the class-balanced cross-entropy:

L_dice = 1 - 2 Σ(Y* · Y^) / (Σ Y* + Σ Y^)
def east_dice_loss(overly_small_text_region_training_mask, text_region_boundary_training_mask, loss_weight, small_text_weight):
    def loss(y_true, y_pred):
        eps = 1e-5
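        # mask out pixels in overly small text regions and on text-region boundaries
        # (small_text_weight can soften the small-text mask)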
        _training_mask = keras.backend.minimum(overly_small_text_region_training_mask + small_text_weight, 1) * text_region_boundary_training_mask
        intersection = backend.reduce_sum(y_true * y_pred * _training_mask)
        union = backend.reduce_sum(y_true * _training_mask) + backend.reduce_sum(y_pred * _training_mask) + eps
        loss = 1. - (2. * intersection / union)
        return loss * loss_weight
    return loss
           
    if multi_gpu > 1:
        from keras.utils import multi_gpu_model
        with tf.device('/cpu:0'):
            model = model_with_weights(backbone_east(modifier=modifier, config=config),
                                       weights=weights, skip_mismatch=True)
        training_model = multi_gpu_model(model, gpus=multi_gpu)
    else:
        model          = model_with_weights(backbone_east(modifier=modifier, config=config),
                                        weights=weights, skip_mismatch=True)
        training_model = model

    # make prediction model
    prediction_model = convert_model(model, 'east')

    score_map_loss_weight = keras.backend.variable(0.01, name='score_map_loss_weight')
    small_text_weight = keras.backend.variable(0., name='small_text_weight')

    # compile model
    training_model.compile(
        loss=[losses.east_dice_loss(model.inputs[1], model.inputs[2], score_map_loss_weight, small_text_weight),
              losses.east_rbox_loss(model.inputs[1], model.inputs[2], small_text_weight, model.inputs[3])],
        loss_weights=[1., 1.],
        optimizer=keras.optimizers.adam(lr=lr, clipnorm=0.001)
    )

    return model, training_model, prediction_model
           

Regression loss:

One challenge for text detection is that the sizes of text in natural scene images vary tremendously. Directly using L1 or L2 loss for regression would bias the loss towards larger and longer text regions. Since accurate geometry predictions are needed for both large and small text regions, the regression loss should be scale-invariant. Therefore, an IoU loss is adopted for the AABB part of RBOX, and a scale-normalized smooth-L1 loss for QUAD regression.

RBOX: For the AABB part, the IoU loss in [46] is adopted, since it is invariant against objects of different scales.

L_AABB = -log IoU(R^, R*) = -log(|R^ ∩ R*| / |R^ ∪ R*|)

where R^ represents the predicted AABB geometry and R* its corresponding ground truth. It is easy to see that the width and height of the intersected rectangle |R^ ∩ R*| are

w_i = min(d^_2, d*_2) + min(d^_4, d*_4),   h_i = min(d^_1, d*_1) + min(d^_3, d*_3)

where d_1, d_2, d_3 and d_4 are the distances from a pixel to the top, right, bottom and left boundaries of its rectangle.

The corresponding union area is

|R^ ∪ R*| = |R^| + |R*| - |R^ ∩ R*|

Thus both the intersection and the union area can be computed easily. Next, the loss of the rotation angle is computed as

L_θ = 1 - cos(θ^ - θ*)

where θ^ is the predicted rotation angle and θ* the ground truth. Finally, the overall geometry loss is the weighted sum of the AABB loss and the angle loss, given by

L_g = L_AABB + λ_θ L_θ

where λ_θ is set to 10 in the paper's experiments (the code below uses 20).

Note that L_AABB is computed regardless of the rotation angle. This can be seen as an approximation of the quadrangle IoU when the angle is perfectly predicted. Although this does not hold during training, it can still impose the correct gradient onto the network to learn to predict R^.
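
As a quick hypothetical numeric check of these formulas: with ground-truth distances (d_1, d_2, d_3, d_4) = (10, 20, 10, 20) and predictions (8, 18, 12, 22), both areas are (10+10)(20+20) = (8+12)(18+22) = 800, the intersection is (min(20,18) + min(20,22)) x (min(10,8) + min(10,12)) = 38 x 18 = 684, the union is 800 + 800 - 684 = 916, so IoU ≈ 0.747 and L_AABB = -log(0.747) ≈ 0.29.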

def east_rbox_loss(overly_small_text_region_training_mask, text_region_boundary_training_mask, small_text_weight, target_score_map):
    def loss(y_true, y_pred):
        # d1 -> top, d2->right, d3->bottom, d4->left
        d1_gt, d2_gt, d3_gt, d4_gt, theta_gt = backend.split(value=y_true, num_or_size_splits=5, axis=3)
        d1_pred, d2_pred, d3_pred, d4_pred, theta_pred = backend.split(value=y_pred, num_or_size_splits=5, axis=3)
        area_gt = (d1_gt + d3_gt) * (d2_gt + d4_gt)
        area_pred = (d1_pred + d3_pred) * (d2_pred + d4_pred)
        w_union = keras.backend.minimum(d2_gt, d2_pred) + keras.backend.minimum(d4_gt, d4_pred)
        h_union = keras.backend.minimum(d1_gt, d1_pred) + keras.backend.minimum(d3_gt, d3_pred)
        area_intersect = w_union * h_union
        area_union = area_gt + area_pred - area_intersect
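        # the +1.0 smoothing keeps the log finite when the intersection or union is empty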
        L_AABB = -keras.backend.log((area_intersect + 1.0)/(area_union + 1.0))
        L_theta = 1 - keras.backend.cos(theta_pred - theta_gt)
        L_g = L_AABB + 20 * L_theta
        _training_mask = keras.backend.minimum(overly_small_text_region_training_mask + small_text_weight, 1) * text_region_boundary_training_mask
        return backend.reduce_mean(L_g * target_score_map * _training_mask)
    return loss
           

Appendix: the QUAD geometry loss (not implemented in the code):

For QUAD regression, the smooth-L1 loss is extended with an extra normalization term designed for word quadrangles, which are typically longer in one direction. Let all coordinate values of Q be an ordered set:

C_Q = {x_1, y_1, x_2, y_2, ..., x_4, y_4}

L_g = L_QUAD(Q^, Q*) = min over Q~ ∈ P_Q* of Σ_(c_i ∈ C_Q, c~_i ∈ C_Q~) smoothed-L1(c_i - c~_i) / (8 × N_Q*)

where the normalization term N_Q* is the length of the shorter edge of the quadrangle, given by

N_Q* = min over i = 1..4 of D(p_i, p_(i mod 4)+1)

and P_Q is the set of all equivalent quadrangles of Q* with different vertex orderings. This permutation is required because quadrangle annotations in public training datasets are inconsistent.
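
Since the QUAD loss is not implemented in this repo, here is a minimal NumPy sketch of the formula above; the names are illustrative, and the permutation set P_Q* is approximated by the four cyclic vertex orderings of the ground truth:

import numpy as np

def smooth_l1(x):
    absx = np.abs(x)
    return np.where(absx < 1., 0.5 * x ** 2, absx - 0.5)

def quad_loss(q_pred, q_gt):
    # q_pred, q_gt: (4, 2) arrays of quadrangle vertices
    # N_Q*: shorter edge length of the ground-truth quadrangle
    n_q = min(np.linalg.norm(q_gt[i] - q_gt[(i + 1) % 4]) for i in range(4))
    best = np.inf
    for shift in range(4):
        q_perm = np.roll(q_gt, shift, axis=0)
        loss = np.sum(smooth_l1((q_pred - q_perm).ravel())) / (8. * n_q)
        best = min(best, loss)
    return best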

Model inference:

class East:
    def __init__(self,
                 weight_path,
                 show_log = True,
                 ):
        self.weight_path = weight_path
        self.show_log = show_log

        self.model = ResNetBackbone("resnet50").east(isTest=True, input_size=512)
        if self.show_log:
            print(self.model.summary())
        self.model.load_weights(self.weight_path, by_name=True)

    def detect_img_path(self, path):
        image = read_image_bgr(path)
        return self.detect_rgbimg(image)

    def detect_rgbimg(self,img):
        img_resized, (ratio_h, ratio_w) = self.resize_image(img)
        img_resized = preprocess_image(img_resized, mode='tf-1to1')

        timer = {'net': 0, 'restore': 0, 'nms': 0}
        start = time.time()
        score_map, geo_map = self.model.predict(img_resized[np.newaxis, :, :, :])

        timer['net'] = time.time() - start

        boxes, timer = self.detect(score_map=score_map, geo_map=geo_map, timer=timer)
        if boxes is not None:
            boxes = boxes[:, :8].reshape((-1, 4, 2))
            boxes[:, :, 0] /= ratio_w
            boxes[:, :, 1] /= ratio_h
        duration = time.time() - start
        print('[timing] {}'.format(duration))
        return boxes

    def detect_with_show(self, path):
        image = read_image_rgb(path)
        boxes = self.detect_rgbimg(image)
        for box in boxes:
            # to avoid submitting errors
            box = self.sort_poly(box.astype(np.int32))
            if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3] - box[0]) < 5:
                continue
            cv2.polylines(image, [box.astype(np.int32).reshape((-1, 1, 2))], True, color=(255, 255, 0),
                          thickness=2)

        plt.imshow(image)
        plt.show()

    def resize_image(self, im, max_side_len=2400):
        '''
        resize image to a size multiple of 32 which is required by the network
        :param im: the resized image
        :param max_side_len: limit of max image size to avoid out of memory in gpu
        :return: the resized image and the resize ratio
        '''
        h, w, _ = im.shape

        resize_w = w
        resize_h = h

        # limit the max side
        if max(resize_h, resize_w) > max_side_len:
            ratio = float(max_side_len) / resize_h if resize_h > resize_w else float(max_side_len) / resize_w
        else:
            ratio = 1.
        resize_h = int(resize_h * ratio)
        resize_w = int(resize_w * ratio)

        resize_h = resize_h if resize_h % 32 == 0 else (resize_h // 32) * 32
        resize_w = resize_w if resize_w % 32 == 0 else (resize_w // 32) * 32
        im = cv2.resize(im, (int(resize_w), int(resize_h)))

        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)

        return im, (ratio_h, ratio_w)

    def sort_poly(self, p):
        min_axis = np.argmin(np.sum(p, axis=1))
        p = p[[min_axis, (min_axis + 1) % 4, (min_axis + 2) % 4, (min_axis + 3) % 4]]
        if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
            return p
        else:
            return p[[0, 3, 2, 1]]

    def detect(self, score_map, geo_map, timer, score_map_thresh=0.8, box_thresh=0.1, nms_thres=0.2):
        '''
        restore text boxes from score map and geo map
        :param score_map:
        :param geo_map:
        :param timer:
        :param score_map_thresh: threshold for the score map
        :param box_thresh: threshold for boxes
        :param nms_thres: threshold for nms
        :return:
        '''
        if len(score_map.shape) == 4:
            score_map = score_map[0, :, :, 0]
            geo_map = geo_map[0, :, :, ]
        # filter the score map
        xy_text = np.argwhere(score_map > score_map_thresh)
        # sort the text boxes via the y axis
        xy_text = xy_text[np.argsort(xy_text[:, 0])]
        # restore
        start = time.time()
        text_box_restored = restore_rectangle(xy_text[:, ::-1] * 4, geo_map[xy_text[:, 0], xy_text[:, 1], :])  # N*4*2
        print('{} text boxes before nms'.format(text_box_restored.shape[0]))
        boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
        boxes[:, :8] = text_box_restored.reshape((-1, 8))
        boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
        timer['restore'] = time.time() - start
        # nms part
        start = time.time()
        boxes = nms_locality(boxes.astype(np.float64), nms_thres)
        # boxes = lanms.merge_quadrangle_n9(boxes.astype('float32'), nms_thres)
        timer['nms'] = time.time() - start

        if boxes.shape[0] == 0:
            return None, timer

        # here we filter out some low-score boxes by the average score map value inside each box; this differs from the original paper
        for i, box in enumerate(boxes):
            mask = np.zeros_like(score_map, dtype=np.uint8)
            cv2.fillPoly(mask, box[:8].reshape((-1, 4, 2)).astype(np.int32) // 4, 1)
            boxes[i, 8] = cv2.mean(score_map, mask)[0]
        boxes = boxes[boxes[:, 8] > box_thresh]

        return boxes, timer

def restore_rectangle_rbox(origin, geometry):
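    # decode per-pixel RBOX geometry: d holds the [top, right, bottom, left] distances and
    # angle the rotation; rectangle corners are rebuilt around each origin pixel and rotated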
    d = geometry[:, :4]
    angle = geometry[:, 4]
    # for angle > 0
    origin_0 = origin[angle >= 0]
    d_0 = d[angle >= 0]
    angle_0 = angle[angle >= 0]
    if origin_0.shape[0] > 0:
        p = np.array([np.zeros(d_0.shape[0]), -d_0[:, 0] - d_0[:, 2],
                      d_0[:, 1] + d_0[:, 3], -d_0[:, 0] - d_0[:, 2],
                      d_0[:, 1] + d_0[:, 3], np.zeros(d_0.shape[0]),
                      np.zeros(d_0.shape[0]), np.zeros(d_0.shape[0]),
                      d_0[:, 3], -d_0[:, 2]])
        p = p.transpose((1, 0)).reshape((-1, 5, 2))  # N*5*2

        rotate_matrix_x = np.array([np.cos(angle_0), np.sin(angle_0)]).transpose((1, 0))
        rotate_matrix_x = np.repeat(rotate_matrix_x, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1))  # N*5*2

        rotate_matrix_y = np.array([-np.sin(angle_0), np.cos(angle_0)]).transpose((1, 0))
        rotate_matrix_y = np.repeat(rotate_matrix_y, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1))

        p_rotate_x = np.sum(rotate_matrix_x * p, axis=2)[:, :, np.newaxis]  # N*5*1
        p_rotate_y = np.sum(rotate_matrix_y * p, axis=2)[:, :, np.newaxis]  # N*5*1

        p_rotate = np.concatenate([p_rotate_x, p_rotate_y], axis=2)  # N*5*2

        p3_in_origin = origin_0 - p_rotate[:, 4, :]
        new_p0 = p_rotate[:, 0, :] + p3_in_origin  # N*2
        new_p1 = p_rotate[:, 1, :] + p3_in_origin
        new_p2 = p_rotate[:, 2, :] + p3_in_origin
        new_p3 = p_rotate[:, 3, :] + p3_in_origin

        new_p_0 = np.concatenate([new_p0[:, np.newaxis, :], new_p1[:, np.newaxis, :],
                                  new_p2[:, np.newaxis, :], new_p3[:, np.newaxis, :]], axis=1)  # N*4*2
    else:
        new_p_0 = np.zeros((0, 4, 2))
    # for angle < 0
    origin_1 = origin[angle < 0]
    d_1 = d[angle < 0]
    angle_1 = angle[angle < 0]
    if origin_1.shape[0] > 0:
        p = np.array([-d_1[:, 1] - d_1[:, 3], -d_1[:, 0] - d_1[:, 2],
                      np.zeros(d_1.shape[0]), -d_1[:, 0] - d_1[:, 2],
                      np.zeros(d_1.shape[0]), np.zeros(d_1.shape[0]),
                      -d_1[:, 1] - d_1[:, 3], np.zeros(d_1.shape[0]),
                      -d_1[:, 1], -d_1[:, 2]])
        p = p.transpose((1, 0)).reshape((-1, 5, 2))  # N*5*2

        rotate_matrix_x = np.array([np.cos(-angle_1), -np.sin(-angle_1)]).transpose((1, 0))
        rotate_matrix_x = np.repeat(rotate_matrix_x, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1))  # N*5*2

        rotate_matrix_y = np.array([np.sin(-angle_1), np.cos(-angle_1)]).transpose((1, 0))
        rotate_matrix_y = np.repeat(rotate_matrix_y, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1))

        p_rotate_x = np.sum(rotate_matrix_x * p, axis=2)[:, :, np.newaxis]  # N*5*1
        p_rotate_y = np.sum(rotate_matrix_y * p, axis=2)[:, :, np.newaxis]  # N*5*1

        p_rotate = np.concatenate([p_rotate_x, p_rotate_y], axis=2)  # N*5*2

        p3_in_origin = origin_1 - p_rotate[:, 4, :]
        new_p0 = p_rotate[:, 0, :] + p3_in_origin  # N*2
        new_p1 = p_rotate[:, 1, :] + p3_in_origin
        new_p2 = p_rotate[:, 2, :] + p3_in_origin
        new_p3 = p_rotate[:, 3, :] + p3_in_origin

        new_p_1 = np.concatenate([new_p0[:, np.newaxis, :], new_p1[:, np.newaxis, :],
                                  new_p2[:, np.newaxis, :], new_p3[:, np.newaxis, :]], axis=1)  # N*4*2
    else:
        new_p_1 = np.zeros((0, 4, 2))
    return np.concatenate([new_p_0, new_p_1])
def restore_rectangle(origin, geometry):
    return restore_rectangle_rbox(origin, geometry)
           

The code above reconstructs, from the predicted score map and geometry map, the rotated rectangles that bound the text regions.
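
A minimal, hypothetical usage of the East wrapper above (the weight file name and image path are placeholders):

east = East(weight_path='east_resnet50.h5', show_log=False)
boxes = east.detect_img_path('demo.jpg')   # (N, 4, 2) array of quadrangle vertices, or None
east.detect_with_show('demo.jpg')          # draws the detected boxes with matplotlib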

Detection results:

[Sample detection results]

Reference article:

https://blog.csdn.net/qq_34886403/article/details/86710446

Original paper:

https://arxiv.org/pdf/1704.03155v2.pdf
