# 天天看点

# 深度学习目标检测图片标记工具 (image-labeling tool for deep-learning object detection)

import os
import cv2
import sys
import argparse
import numpy as np

"""
Workflow of this tool:
1. Walk the images in the root directory; in browse mode (-b) an image that
   already has a label file is shown with its saved ground truth drawn on it.
2. Show one image and keep redrawing it in a while loop so interaction continues.
3. Left clicks place vertices: four clicks form a quadrilateral, which is drawn
   and whose four points are stored in a temporary list.
4. Releasing the right mouse button inside a finished quadrilateral deletes it
   and updates the list.
5. Press SPACE to finish the current image: the list is saved as that image's
   ground truth and the next image is shown.
"""

KEY_EMPTY = 0  # initial "no key pressed yet" value
KEY_SPACE = 32  # SPACE finishes the current image and moves to the next one
KEY_ESC = 27  # ESC quits the whole program
# Image file extensions (lower case, without the dot) that are picked up.
# The original list spelled 'jpeg' as 'jepg', so .jpeg files were skipped.
extensions = ['jpg', 'jpeg', 'png']


parser = argparse.ArgumentParser(description="your script description")
# -b / --browse: only display images that already have a label file.
parser.add_argument('--browse', '-b', action='store_true', help='browse mode')


class mylabel:
    """Interactive OpenCV annotator that marks one image with quadrilaterals.

    Left clicks place vertices; every fourth click closes a quadrilateral.
    Releasing the right mouse button deletes every stored quadrilateral that
    contains the cursor. ``label`` runs the editing loop and saves on SPACE;
    ``browseBox`` displays a previously saved annotation.

    Each entry of ``gt_boxes`` is a one-element list wrapping four ``[x, y]``
    vertices: ``[[[x1, y1], [x2, y2], [x3, y3], [x4, y4]]]``.
    """

    def __init__(self, filepath, labelpath):
        """Remember the image path and the path of its label text file."""
        self.filepath = filepath  # image path
        self.labelpath = labelpath  # text file holding the boxes of this image
        self.vertex_num = 0  # total number of vertices clicked so far
        self.gt_boxes = []  # all finished boxes
        # The box currently being drawn; -1 marks a vertex not yet placed.
        self.temp_box = [[[-1, -1], [-1, -1], [-1, -1], [-1, -1]]]
        self.temp_pt = (-1, -1)  # last mouse position reported by OpenCV

    def _isRayIntersectsSegment(self, poi, s_poi, e_poi):  # [x,y] [lng,lat]
        """Tell whether the horizontal ray from ``poi`` towards +x hits an edge.

        :param poi: point being tested, ``[x, y]``
        :param s_poi: start point of the edge
        :param e_poi: end point of the edge
        :return: True if the ray crosses the edge, otherwise False
        """
        if s_poi[1] == e_poi[1]:
            # Edge parallel to (or lying on) the ray, or a zero-length edge.
            return False
        if s_poi[1] > poi[1] and e_poi[1] > poi[1]:  # edge entirely above the ray
            return False
        if s_poi[1] < poi[1] and e_poi[1] < poi[1]:  # edge entirely below the ray
            return False
        # An edge touching the ray only with its lower-y endpoint is not
        # counted, so a vertex shared by two edges is counted at most once.
        if s_poi[1] == poi[1] and e_poi[1] > poi[1]:
            return False
        if e_poi[1] == poi[1] and s_poi[1] > poi[1]:
            return False
        # Edge entirely to the left of the ray origin. The original compared
        # e_poi[1] (a y value) here, which rejected edges that really cross.
        if s_poi[0] < poi[0] and e_poi[0] < poi[0]:
            return False

        # x coordinate where the edge reaches the height of the ray.
        xseg = e_poi[0] - (e_poi[0] - s_poi[0]) * (e_poi[1] - poi[1]) / (e_poi[1] - s_poi[1])
        # A crossing left of the ray origin does not count.
        return not xseg < poi[0]

    def _isPoiWithinPoly(self, poi, box):
        """Ray-casting point-in-polygon test.

        :param poi: point being tested, ``[x, y]``
        :param box: polygon as a sequence of ``[x, y]`` vertices
        :return: True if the point is inside the polygon, otherwise False
        """
        vertex_count = len(box)
        crossings = 0
        for i in range(vertex_count):
            start_poi = box[i]
            end_poi = box[(i + 1) % vertex_count]  # wrap around to close the polygon
            if self._isRayIntersectsSegment(poi, start_poi, end_poi):
                crossings += 1
        # An odd number of crossings means the point is inside.
        return crossings % 2 == 1

    def _drawLine(self, img):
        """Draw the edges of the unfinished box plus a line to the cursor.

        :param img: canvas to draw on (modified in place)
        :return: the same canvas
        """
        vertex_num = self.vertex_num % 4  # vertices placed in the current box
        box = self.temp_box[0]
        if vertex_num < 1:
            return img
        for i in range(vertex_num - 1):
            start_vertex = (box[i][0], box[i][1])
            end_vertex = (box[i + 1][0], box[i + 1][1])
            cv2.line(img, start_vertex, end_vertex, (0, 255, 0), 2)
        # Rubber-band line from the last placed vertex to the mouse position.
        move_start_vertex = (box[vertex_num - 1][0], box[vertex_num - 1][1])
        cv2.line(img, move_start_vertex, self.temp_pt, (0, 255, 0), 2)
        return img

    def _drawBox(self, img):
        """Draw every stored box as a closed red outline.

        :param img: canvas to draw on (modified in place)
        :return: the same canvas
        """
        for entry in self.gt_boxes:
            # cv2.polylines expects 32-bit integer points; the platform default
            # integer dtype is not always int32. The reshape accepts an entry
            # that wraps one box as well as one that holds several boxes.
            polygons = np.asarray(entry, dtype=np.int32).reshape(-1, 4, 2)
            if len(polygons):
                cv2.polylines(img, list(polygons), True, (0, 0, 255), 2)
        return img

    def _removeBoxesAt(self, x, y):
        """Delete every stored box that contains the point ``(x, y)``."""
        # Build the surviving list instead of removing while iterating, which
        # would skip the element following each removed box.
        self.gt_boxes[:] = [
            box for box in self.gt_boxes
            if not self._isPoiWithinPoly([x, y], box[0])
        ]

    def _mouseHandler(self, event, x, y, flags, param):
        """Mouse callback registered with OpenCV for the labeling window.

        Left button down places the next vertex, mouse move tracks the cursor,
        right button up deletes the boxes under the cursor.
        """
        if event == cv2.EVENT_LBUTTONDOWN:
            vertex_num = self.vertex_num % 4
            self.temp_box[0][vertex_num][0] = x
            self.temp_box[0][vertex_num][1] = y
            if vertex_num == 3:  # fourth vertex: the quadrilateral is complete
                self.gt_boxes.append(self.temp_box)
                self.temp_box = [[[-1, -1], [-1, -1], [-1, -1], [-1, -1]]]
            self.vertex_num += 1
        elif event == cv2.EVENT_MOUSEMOVE:
            self.temp_pt = (x, y)
        elif event == cv2.EVENT_RBUTTONUP:
            self._removeBoxesAt(x, y)

    def _saveBox(self):
        """Write all boxes to ``labelpath``, one ``x1,y1,...,x4,y4`` line each."""
        lines = []
        for box in self.gt_boxes:
            coords = []
            for vertex in box[0]:
                coords.append(str(vertex[0]))
                coords.append(str(vertex[1]))
            lines.append(",".join(coords) + "\n")
        with open(self.labelpath, 'w') as txt:
            txt.write("".join(lines))

    def browseBox(self):
        """Show the image with the boxes stored in its label file.

        Blocks until a key is pressed, then closes the window.
        """
        img = cv2.imread(self.filepath)
        if img is None:  # cv2.imread signals failure by returning None
            print("cannot read image: " + self.filepath)
            return
        gt_boxes = []
        with open(self.labelpath, 'r') as label:
            for line in label:
                line = line.strip()
                if not line:  # tolerate blank lines
                    continue
                values = [int(value) for value in line.split(",")]
                points = [[values[2 * i], values[2 * i + 1]] for i in range(4)]
                # Same nesting as the boxes created by the mouse handler.
                gt_boxes.append([points])
        self.gt_boxes = gt_boxes
        cv2.namedWindow(self.filepath)
        cv2.imshow(self.filepath, self._drawBox(img))
        cv2.waitKey()
        # Always close the window so windows do not pile up across images.
        cv2.destroyAllWindows()

    def label(self):
        """Run the interactive labeling loop for this image.

        SPACE saves the boxes to ``labelpath`` and returns; ESC closes the
        windows and exits the program without saving.
        """
        img = cv2.imread(self.filepath)
        if img is None:  # cv2.imread signals failure by returning None
            print("cannot read image: " + self.filepath)
            return
        cv2.namedWindow("mylabel")
        cv2.setMouseCallback("mylabel", self._mouseHandler)
        key = KEY_EMPTY
        delay = int(1000 / 20)  # milliseconds per redraw, about 20 per second

        while key != KEY_SPACE:
            if key == KEY_ESC:
                cv2.destroyAllWindows()
                sys.exit()
            canvas = img.copy()  # draw on a copy so the source image stays clean
            canvas = self._drawBox(canvas)
            canvas = self._drawLine(canvas)
            cv2.imshow("mylabel", canvas)
            key = cv2.waitKey(delay)

        # The loop only ends when SPACE was pressed.
        self._saveBox()
        cv2.destroyAllWindows()
    
            
def label_path_for(root_dir, image_name):
    """Return the path of the label text file that belongs to an image.

    Only the final extension is replaced, so ``a.b.jpg`` maps to ``a.b.txt``
    rather than ``a.txt``.

    :param root_dir: directory containing the image
    :param image_name: file name of the image, without directory
    :return: path of the matching ``.txt`` file inside ``root_dir``
    """
    stem = os.path.splitext(image_name)[0]
    return os.path.join(root_dir, stem + ".txt")


if __name__ == "__main__":
    args = parser.parse_args()
    # Built with os.path.join: the literal ".\samples" contained the invalid
    # escape "\s" and only worked with Windows path separators.
    root_dir = os.path.join(".", "samples")
    # Sorted so images are visited in a reproducible order.
    for file in sorted(os.listdir(root_dir)):
        # Compare the extension case-insensitively so "A.JPG" is picked up.
        extension = os.path.splitext(file)[1].lstrip(".").lower()
        if extension not in extensions:
            continue
        filepath = os.path.join(root_dir, file)
        labelpath = label_path_for(root_dir, file)
        is_labeled = os.path.exists(labelpath)  # a label file means already annotated
        if args.browse and is_labeled:
            print(filepath)
            ml = mylabel(filepath, labelpath)
            ml.browseBox()
        elif is_labeled:
            # Never re-label an image that already has a label file.
            continue
        else:
            ml = mylabel(filepath, labelpath)
            ml.label()
           

# 继续阅读