yolo系列是目辨別别的重頭戲,為了更好的了解掌握它,我們必須從源碼出發深刻了解代碼。下面我們來講解pytorch實作的yolov3源碼。在講解之前,大家應該具備相應的原理知識yolov1,yolov2,yolov3。
大部分同學在看論文時并不能把所有的知識全部掌握。我們必須結合代碼(代碼将理論變成實踐),它是百分百還原理論的,也隻有在掌握代碼以及理論後,我們才能推陳出新有所收獲,是以大家平時一定多接觸代碼,這裡我們會結合yolov3的理論知識讓大家真正在代碼中了解思想。
下面我就train過程的代碼進行講解。在了解train過程之前,建議大家先了解inference的代碼講解。
PyTorch實作yolov3代碼詳細解密
資料讀取:
Pytorch讀取圖檔,主要通過Dataset類,Dataset類作為所有的datasets的基類存在,所有的datasets都需要繼承它。
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
這裡重點看getitem函數,getitem接收一個index,傳回圖檔資料和labels。我們看yolov3的dataset。
dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True)
class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=416, augment=False):
with open(path, 'r') as file:
img_files = file.read().splitlines()
self.img_files = list(filter(lambda x: len(x) > 0, img_files))
n = len(self.img_files)
assert n > 0, 'No images found in %s' % path
self.img_size = img_size
self.augment = augment
self.label_files = [
x.replace('images', 'labels').replace('.bmp', '.txt').replace('.jpg', '.txt').replace('.png', '.txt')
for x in self.img_files]
# if n < 200: # preload all images into memory if possible
# self.imgs = [cv2.imread(img_files[i]) for i in range(n)]
def __len__(self):
return len(self.img_files)
def __getitem__(self, index):
img_path = self.img_files[index]
label_path = self.label_files[index]
# if hasattr(self, 'imgs'):
# img = self.imgs[index] # BGR
img = cv2.imread(img_path) # BGR
assert img is not None, 'File Not Found ' + img_path
h, w, _ = img.shape
img, ratio, padw, padh = letterbox(img, height=self.img_size)
#将每幅圖resize成418*418
# Load labels
labels = []
if os.path.isfile(label_path):
with open(label_path, 'r') as file:
lines = file.read().splitlines()
x = np.array([x.split() for x in lines], dtype=np.float32)
if x.size > 0:
# Normalized xywh to pixel xyxy format
labels = x.copy()
labels[:, 1] = ratio * w * (x[:, 1] - x[:, 3] / 2) + padw
labels[:, 2] = ratio * h * (x[:, 2] - x[:, 4] / 2) + padh
labels[:, 3] = ratio * w * (x[:, 1] + x[:, 3] / 2) + padw
labels[:, 4] = ratio * h * (x[:, 2] + x[:, 4] / 2) + padh
print(labels)
# Augment image and labels
if self.augment:
img, labels = random_affine(img, labels, degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.90, 1.10))
nL = len(labels) # number of labels
if nL:
# convert xyxy to xywh
labels[:, 1:5] = xyxy2xywh(labels[:, 1:5]) / self.img_size
if self.augment:
# random left-right flip
lr_flip = True
if lr_flip and random.random() > 0.5:
img = np.fliplr(img)
if nL:
labels[:, 1] = 1 - labels[:, 1]
# random up-down flip
ud_flip = False
if ud_flip and random.random() > 0.5:
img = np.flipud(img)
if nL:
labels[:, 2] = 1 - labels[:, 2]
labels_out = torch.zeros((nL, 6))
if nL:
labels_out[:, 1:] = torch.from_numpy(labels)
# Normalize
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
img = np.ascontiguousarray(img, dtype=np.float32) # uint8 to float32
img /= 255.0 # 0 - 255 to 0.0 - 1.0
return torch.from_numpy(img), labels_out, img_path, (h, w)
dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True),可以看到其中LoadImagesAndLabels類是Dataset的子類,init函數是正常的讀取資料,我們主要看getitem,getitem接收一個index,就是img_files的索引,通過letterbox函數進行資料預處理将每幅圖resize成418*418,labels裡面存放的是ground truth的類别和坐标資訊,因為圖像resize了,是以labels中的坐标資訊也要相對變化。最後傳回處理後的img,labels,位址和寬高。
那麼讀取自己資料的基本流程就是:
1:制作存儲了圖像的路徑和标簽資訊的txt
2:将這些資訊轉化為list,該list每一個元素對應一個樣本
3:通過getitem函數,讀取資料标簽,并傳回。
在訓練代碼裡是感覺不到這些操作的,隻會看到通過DataLoader就可以擷取一個batch的資料,其實觸發去讀取圖檔這些操作的是DataLoader裡的__iter__(self),流程較長的描述如下:
1.從dataset類中初始化txt,txt中有圖檔路徑和标簽
2.初始化DataLoder時,将dataset傳入,進而使DataLoader擁有圖檔路徑
3.在for i, (imgs, targets, _, _) in enumerate(dataloader):中,一個iteration進行時,讀取一個batch的資料,enumerate将資料傳回到imgs,targets中,imgs就是資料增強後的圖像,labels就是處理後的标簽。
4.讀取過程中需要在class DataLoader()類中調用_DataLoderIter()
5.在 _DataLoderiter()類中跳到 next(self)函數,在該函數中通過indices = next(self.sample_iter)擷取一個batch的indices,再通過batch=self.collate_fn()擷取一個batch資料。
6.self.collate_fn中調用LoadImagesAndLabels類中的 getitem()函數,再函數中擷取圖檔。
如此,我們第一步資料預處理就完成了,後面我們就可以把資料imgs放到模型裡跑了。大家不要忽視這些代碼,想真正弄懂,我們就要一步一步刨根問底。
下面一章,我們會根據程式複現訓練過程的算法原理,講解yolov3的loss是如何計算的。