天天看点

Pytorch实现yolov3(train)训练代码详解(一)

yolo系列是目标识别的重头戏,为了更好的理解掌握它,我们必须从源码出发深刻理解代码。下面我们来讲解pytorch实现的yolov3源码。在讲解之前,大家应该具备相应的原理知识yolov1,yolov2,yolov3。

大部分同学在看论文时并不能把所有的知识全部掌握。我们必须结合代码(代码将理论变成实践),它是百分百还原理论的,也只有在掌握代码以及理论后,我们才能推陈出新有所收获,所以大家平时一定多接触代码,这里我们会结合yolov3的理论知识让大家真正在代码中理解思想。

下面我就train过程的代码进行讲解。在理解train过程之前,建议大家先了解inference的代码讲解。

PyTorch实现yolov3代码详细解密

数据读取:

Pytorch读取图片,主要通过Dataset类,Dataset类作为所有的datasets的基类存在,所有的datasets都需要继承它。

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
           

这里重点看getitem函数,getitem接收一个index,返回图片数据和labels。我们看yolov3的dataset。

dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True)
class LoadImagesAndLabels(Dataset):  # for training/testing
    def __init__(self, path, img_size=416, augment=False):
        with open(path, 'r') as file:
            img_files = file.read().splitlines()
            self.img_files = list(filter(lambda x: len(x) > 0, img_files))

        n = len(self.img_files)
        assert n > 0, 'No images found in %s' % path
        self.img_size = img_size
        self.augment = augment
        self.label_files = [
            x.replace('images', 'labels').replace('.bmp', '.txt').replace('.jpg', '.txt').replace('.png', '.txt')
            for x in self.img_files]

        # if n < 200:  # preload all images into memory if possible
        #    self.imgs = [cv2.imread(img_files[i]) for i in range(n)]

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, index):
        img_path = self.img_files[index]
        label_path = self.label_files[index]

        # if hasattr(self, 'imgs'):
        #    img = self.imgs[index]  # BGR
        img = cv2.imread(img_path)  # BGR
        assert img is not None, 'File Not Found ' + img_path
        h, w, _ = img.shape
        img, ratio, padw, padh = letterbox(img, height=self.img_size)

        #将每幅图resize成418*418

        # Load labels
        labels = []
        if os.path.isfile(label_path):
            with open(label_path, 'r') as file:
                lines = file.read().splitlines()
            x = np.array([x.split() for x in lines], dtype=np.float32)
            if x.size > 0:
                # Normalized xywh to pixel xyxy format
                labels = x.copy()
                labels[:, 1] = ratio * w * (x[:, 1] - x[:, 3] / 2) + padw
                labels[:, 2] = ratio * h * (x[:, 2] - x[:, 4] / 2) + padh
                labels[:, 3] = ratio * w * (x[:, 1] + x[:, 3] / 2) + padw
                labels[:, 4] = ratio * h * (x[:, 2] + x[:, 4] / 2) + padh
                print(labels)
        # Augment image and labels
        if self.augment:
            img, labels = random_affine(img, labels, degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.90, 1.10))

        nL = len(labels)  # number of labels
        if nL:
            # convert xyxy to xywh
            labels[:, 1:5] = xyxy2xywh(labels[:, 1:5]) / self.img_size

        if self.augment:
            # random left-right flip
            lr_flip = True
            if lr_flip and random.random() > 0.5:
                img = np.fliplr(img)
                if nL:
                    labels[:, 1] = 1 - labels[:, 1]

            # random up-down flip
            ud_flip = False
            if ud_flip and random.random() > 0.5:
                img = np.flipud(img)
                if nL:
                    labels[:, 2] = 1 - labels[:, 2]

        labels_out = torch.zeros((nL, 6))
        if nL:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Normalize
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img, dtype=np.float32)  # uint8 to float32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0

        return torch.from_numpy(img), labels_out, img_path, (h, w)
           

dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True),可以看到其中LoadImagesAndLabels类是Dataset的子类,init函数是正常的读取数据,我们主要看getitem,getitem接收一个index,就是img_files的索引,通过letterbox函数进行数据预处理将每幅图resize成418*418,labels里面存放的是ground truth的类别和坐标信息,因为图像resize了,所以labels中的坐标信息也要相对变化。最后返回处理后的img,labels,地址和宽高。

那么读取自己数据的基本流程就是:

1:制作存储了图像的路径和标签信息的txt

2:将这些信息转化为list,该list每一个元素对应一个样本

3:通过getitem函数,读取数据标签,并返回。

在训练代码里是感觉不到这些操作的,只会看到通过DataLoader就可以获取一个batch的数据,其实触发去读取图片这些操作的是DataLoader里的__iter__(self),流程详细描述如下:

1.从dataset类中初始化txt,txt中有图片路径和标签

2.初始化DataLoder时,将dataset传入,从而使DataLoader拥有图片路径

3.在for i, (imgs, targets, _, _) in enumerate(dataloader):中,一个iteration进行时,读取一个batch的数据,enumerate将数据返回到imgs,targets中,imgs就是数据增强后的图像,labels就是处理后的标签。

4.读取过程中需要在class DataLoader()类中调用_DataLoderIter()

5.在 _DataLoderiter()类中跳到 next(self)函数,在该函数中通过indices = next(self.sample_iter)获取一个batch的indices,再通过batch=self.collate_fn()获取一个batch数据。

6.self.collate_fn中调用LoadImagesAndLabels类中的 getitem()函数,再函数中获取图片。

如此,我们第一步数据预处理就完成了,后面我们就可以把数据imgs放到模型里跑了。大家不要忽视这些代码,想真正弄懂,我们就要一步一步刨根问底。

下面一章,我们会根据程序复现训练过程的算法原理,讲解yolov3的loss是如何计算的。

继续阅读