天天看點

Pytorch實作yolov3(train)訓練代碼詳解(一)

yolo系列是目辨別别的重頭戲,為了更好的了解掌握它,我們必須從源碼出發深刻了解代碼。下面我們來講解pytorch實作的yolov3源碼。在講解之前,大家應該具備相應的原理知識yolov1,yolov2,yolov3。

大部分同學在看論文時并不能把所有的知識全部掌握。我們必須結合代碼(代碼将理論變成實踐),它是百分百還原理論的,也隻有在掌握代碼以及理論後,我們才能推陳出新有所收獲,是以大家平時一定多接觸代碼,這裡我們會結合yolov3的理論知識讓大家真正在代碼中了解思想。

下面我就train過程的代碼進行講解。在了解train過程之前,建議大家先了解inference的代碼講解。

PyTorch實作yolov3代碼詳細解密

資料讀取:

Pytorch讀取圖檔,主要通過Dataset類,Dataset類作為所有的datasets的基類存在,所有的datasets都需要繼承它。

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
           

這裡重點看getitem函數,getitem接收一個index,傳回圖檔資料和labels。我們看yolov3的dataset。

dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True)
class LoadImagesAndLabels(Dataset):  # for training/testing
    def __init__(self, path, img_size=416, augment=False):
        with open(path, 'r') as file:
            img_files = file.read().splitlines()
            self.img_files = list(filter(lambda x: len(x) > 0, img_files))

        n = len(self.img_files)
        assert n > 0, 'No images found in %s' % path
        self.img_size = img_size
        self.augment = augment
        self.label_files = [
            x.replace('images', 'labels').replace('.bmp', '.txt').replace('.jpg', '.txt').replace('.png', '.txt')
            for x in self.img_files]

        # if n < 200:  # preload all images into memory if possible
        #    self.imgs = [cv2.imread(img_files[i]) for i in range(n)]

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, index):
        img_path = self.img_files[index]
        label_path = self.label_files[index]

        # if hasattr(self, 'imgs'):
        #    img = self.imgs[index]  # BGR
        img = cv2.imread(img_path)  # BGR
        assert img is not None, 'File Not Found ' + img_path
        h, w, _ = img.shape
        img, ratio, padw, padh = letterbox(img, height=self.img_size)

        #将每幅圖resize成418*418

        # Load labels
        labels = []
        if os.path.isfile(label_path):
            with open(label_path, 'r') as file:
                lines = file.read().splitlines()
            x = np.array([x.split() for x in lines], dtype=np.float32)
            if x.size > 0:
                # Normalized xywh to pixel xyxy format
                labels = x.copy()
                labels[:, 1] = ratio * w * (x[:, 1] - x[:, 3] / 2) + padw
                labels[:, 2] = ratio * h * (x[:, 2] - x[:, 4] / 2) + padh
                labels[:, 3] = ratio * w * (x[:, 1] + x[:, 3] / 2) + padw
                labels[:, 4] = ratio * h * (x[:, 2] + x[:, 4] / 2) + padh
                print(labels)
        # Augment image and labels
        if self.augment:
            img, labels = random_affine(img, labels, degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.90, 1.10))

        nL = len(labels)  # number of labels
        if nL:
            # convert xyxy to xywh
            labels[:, 1:5] = xyxy2xywh(labels[:, 1:5]) / self.img_size

        if self.augment:
            # random left-right flip
            lr_flip = True
            if lr_flip and random.random() > 0.5:
                img = np.fliplr(img)
                if nL:
                    labels[:, 1] = 1 - labels[:, 1]

            # random up-down flip
            ud_flip = False
            if ud_flip and random.random() > 0.5:
                img = np.flipud(img)
                if nL:
                    labels[:, 2] = 1 - labels[:, 2]

        labels_out = torch.zeros((nL, 6))
        if nL:
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Normalize
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img, dtype=np.float32)  # uint8 to float32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0

        return torch.from_numpy(img), labels_out, img_path, (h, w)
           

dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True),可以看到其中LoadImagesAndLabels類是Dataset的子類,init函數是正常的讀取資料,我們主要看getitem,getitem接收一個index,就是img_files的索引,通過letterbox函數進行資料預處理将每幅圖resize成418*418,labels裡面存放的是ground truth的類别和坐标資訊,因為圖像resize了,是以labels中的坐标資訊也要相對變化。最後傳回處理後的img,labels,位址和寬高。

那麼讀取自己資料的基本流程就是:

1:制作存儲了圖像的路徑和标簽資訊的txt

2:将這些資訊轉化為list,該list每一個元素對應一個樣本

3:通過getitem函數,讀取資料标簽,并傳回。

在訓練代碼裡是感覺不到這些操作的,隻會看到通過DataLoader就可以擷取一個batch的資料,其實觸發去讀取圖檔這些操作的是DataLoader裡的__iter__(self),流程較長的描述如下:

1.從dataset類中初始化txt,txt中有圖檔路徑和标簽

2.初始化DataLoder時,将dataset傳入,進而使DataLoader擁有圖檔路徑

3.在for i, (imgs, targets, _, _) in enumerate(dataloader):中,一個iteration進行時,讀取一個batch的資料,enumerate将資料傳回到imgs,targets中,imgs就是資料增強後的圖像,labels就是處理後的标簽。

4.讀取過程中需要在class DataLoader()類中調用_DataLoderIter()

5.在 _DataLoderiter()類中跳到 next(self)函數,在該函數中通過indices = next(self.sample_iter)擷取一個batch的indices,再通過batch=self.collate_fn()擷取一個batch資料。

6.self.collate_fn中調用LoadImagesAndLabels類中的 getitem()函數,再函數中擷取圖檔。

如此,我們第一步資料預處理就完成了,後面我們就可以把資料imgs放到模型裡跑了。大家不要忽視這些代碼,想真正弄懂,我們就要一步一步刨根問底。

下面一章,我們會根據程式複現訓練過程的算法原理,講解yolov3的loss是如何計算的。

繼續閱讀