天天看點

SSD-Pytorch模型訓練自己的資料集here we specify year (07 or 12) and dataset (‘test’, ‘val’, ‘train’)

1.下載下傳SSD-Pytorch代碼

SSD-pytorch代碼連結: https://github.com/amdegroot/ssd.pytorch

git clone https://github.com/amdegroot/ssd.pytorch
           
  1. 運作該代碼下載下傳到本地(如果下載下傳太慢可以上傳到碼雲,然後git clone碼雲位址)

2.準備資料集

  1. 沒有資料集的同學可以下載下傳代碼自帶的VOC和COCO資料集(./data/scripts目錄下)
SSD-Pytorch模型訓練自己的資料集here we specify year (07 or 12) and dataset (‘test’, ‘val’, ‘train’)
  1. 有自己的資料集請将資料集放置在./data目錄下,例如VOC格式資料集,建立VOCdevkit檔案夾,如下圖所示,可以參考:https://blog.csdn.net/qq_34806812/article/details/81673798.
  2. 在Annotations中放置所有的标簽,在JPEGimages中放置所有的圖檔,在ImagesSets/Main中放置train.txt/val.txt/test.txt(内容隻有圖檔的名字,例如:00001,00002,不能帶字尾jpg或者png)等,可以用腳本自己生成:https://blog.csdn.net/GeekYao/article/details/105074574.
SSD-Pytorch模型訓練自己的資料集here we specify year (07 or 12) and dataset (‘test’, ‘val’, ‘train’)

3.根據自己的資料集修改代碼

  1. 部落客用的VOC格式的資料集,下面修改都是以VOC格式為例

3.1 config.py

  1. 找到config.py檔案,
  2. 打開修改VOC中的num_classes,根據自己的情況修改:classes+1(背景算一類),
  3. 我這裡就隻有一類,所有是2
  4. 第一次調試最好修改一下max_iter,不然疊代次數太大,要好長時間,其他都是一些超參數,可以占時不修改

部落客用的VOC格式的資料集,下面修改都是以VOC格式為例

SSD-Pytorch模型訓練自己的資料集here we specify year (07 or 12) and dataset (‘test’, ‘val’, ‘train’)

3.2VOC0712.py

SSD-Pytorch模型訓練自己的資料集here we specify year (07 or 12) and dataset (‘test’, ‘val’, ‘train’)
  1. 根據自己的标簽進行修改,部落客這裡隻有一類,是以隻有一個dargon fruit(注:如果隻有一類,需要加上[ ])
SSD-Pytorch模型訓練自己的資料集here we specify year (07 or 12) and dataset (‘test’, ‘val’, ‘train’)
  1. image_sets中修改一下,根據自己的設定的資料集修改,我這裡隻有train和val

3.3 train.py

SSD-Pytorch模型訓練自己的資料集here we specify year (07 or 12) and dataset (‘test’, ‘val’, ‘train’)

下載下傳預訓練模型。VGG16_reducedfc.pth

連結: https://pan.baidu.com/s/1EW9qT0nJkE2dK7thn_kPVw 密碼: nw6t

–來自百度網盤超級會員V1的分享

SSD-Pytorch模型訓練自己的資料集here we specify year (07 or 12) and dataset (‘test’, ‘val’, ‘train’)
  1. 根據自己的顯存修改batch_size,建議一開始修改小一點,部落客1660ti 6G顯存
SSD-Pytorch模型訓練自己的資料集here we specify year (07 or 12) and dataset (‘test’, ‘val’, ‘train’)
  1. 将儲存訓練模型的參數調低一點,之前iter設定的1000,這裡設定為500,之後根據自己情況在設定
  2. 順便修改一下儲存的模型名字,也可以之後修改,把COCO改成VOC,部落客這裡沒修改

3.4 eval.py

SSD-Pytorch模型訓練自己的資料集here we specify year (07 or 12) and dataset (‘test’, ‘val’, ‘train’)

添加訓練好的模型到eval.py,對模型進行驗證,我這裡訓練好的是ssd300_VOC_500.pth

将下面的

args = parser.parse_args()
           

修改為

args,unknow= parser.parse_known_args()
           

3.5 SSD.py

SSD-Pytorch模型訓練自己的資料集here we specify year (07 or 12) and dataset (‘test’, ‘val’, ‘train’)
  1. 修改num_classes,跟上面config.py中的一緻就行
  2. 修改完成後,運作train.py,完成訓練之後,部落客運作eval.py驗證了訓練的模型,AP隻有63%,可能是部落客資料集太少了

運作eval.py隻能看到AP值,想要測試自己的圖檔,在jupyter notebook中運作demo.ipynb

将對應部分的代碼,修改為以下這樣即可,注意正确添加圖檔的路徑

image = cv2.imread(’…/data/example3.jpg’, cv2.IMREAD_COLOR) # uncomment if dataset not downloaded

from matplotlib import pyplot as plt

from data import VOCDetection, VOC_ROOT, VOCAnnotationTransform

here we specify year (07 or 12) and dataset (‘test’, ‘val’, ‘train’)

#testset = VOCDetection('./data/example1.jpg', [('2020', 'val')], None, VOCAnnotationTransform())
#img_id = 13
#image = testset.pull_image(img_id)
rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(10,10))
plt.imshow(rgb_image)
plt.show()
           
SSD-Pytorch模型訓練自己的資料集here we specify year (07 or 12) and dataset (‘test’, ‘val’, ‘train’)
SSD-Pytorch模型訓練自己的資料集here we specify year (07 or 12) and dataset (‘test’, ‘val’, ‘train’)

可能會存在的問題:

bug1:出現次元不比對的情況

loc_loss += loss_l.data[0] 報錯

SSD-Pytorch模型訓練自己的資料集here we specify year (07 or 12) and dataset (‘test’, ‘val’, ‘train’)

解決方法:

  1. 将.data[0]改為.item(),下面print中的也改為loss.item()
  2. 建議參考:https://github.com/amdegroot/ssd.pytorch/issues/421

bug2:自動停止訓練

解決方法:

SSD-Pytorch模型訓練自己的資料集here we specify year (07 or 12) and dataset (‘test’, ‘val’, ‘train’)

3. load train data部分修改為如上圖所示

bug3:可能會出現pytorch版本帶來的影響問題

解決方法:根據提示語句,百度修改即可

bug4:運作eval.py可能會出現pytest這種情況

解決方法:将eval.py中的test_net函數名字修改即可,不能出現test關鍵字,部落客修改為set_net成功運作

bug5:訓練出現-nan

解決方法:降低學習率

bug6:出現顯存不足的問題Runtimeout

解決方法:降低batch_size

bug7:出現數組索引過多的情況

IndexError: too many indices for array

解決方法:因為有些标注的标簽沒有資料,所有會出現數組索引出錯

如果資料比較多,可以用如下腳本排查是哪個标簽出現問題(注意修改自己的标簽路徑)

import argparse
import sys
import cv2
import os

import os.path          as osp
import numpy            as np

if sys.version_info[0] == 2:
    import xml.etree.cElementTree as ET
else:
    import xml.etree.ElementTree  as ET


parser    = argparse.ArgumentParser(
            description='Single Shot MultiBox Detector Training With Pytorch')
train_set = parser.add_mutually_exclusive_group()

parser.add_argument('--root', default='data/VOCdevkit/VOC2020' , help='Dataset root directory path')

args = parser.parse_args()

CLASSES = [(  # always index 0
    'dargon fruit')]

annopath = osp.join('%s', 'Annotations', '%s.{}'.format("xml"))
imgpath  = osp.join('%s', 'JPEGImages',  '%s.{}'.format("jpg"))

def vocChecker(image_id, width, height, keep_difficult = False):
    target   = ET.parse(annopath % image_id).getroot()
    res      = []

    for obj in target.iter('object'):

        difficult = int(obj.find('difficult').text) == 1

        if not keep_difficult and difficult:
            continue

        name = obj.find('name').text.lower().strip()
        bbox = obj.find('bndbox')

        pts    = ['xmin', 'ymin', 'xmax', 'ymax']
        bndbox = []

        for i, pt in enumerate(pts):

            cur_pt = int(bbox.find(pt).text) - 1
            # scale height or width
            cur_pt = float(cur_pt) / width if i % 2 == 0 else float(cur_pt) / height

            bndbox.append(cur_pt)

        print(name)
        label_idx =  dict(zip(CLASSES, range(len(CLASSES))))[name]
        bndbox.append(label_idx)
        res += [bndbox]  # [xmin, ymin, xmax, ymax, label_ind]
        # img_id = target.find('filename').text[:-4]
    print(res)
    try :
        print(np.array(res)[:,4])
        print(np.array(res)[:,:4])
    except IndexError:
        print("\nINDEX ERROR HERE !\n")
        exit(0)
    return res  # [[xmin, ymin, xmax, ymax, label_ind], ... ]

if __name__ == '__main__' :

    i = 0

    for name in sorted(os.listdir(osp.join(args.root,'Annotations'))):
    # as we have only one annotations file per image
        i += 1

        img    = cv2.imread(imgpath  % (args.root,name.split('.')[0]))
        height, width, channels = img.shape
        print("path : {}".format(annopath % (args.root,name.split('.')[0])))
        res = vocChecker((args.root, name.split('.')[0]), height, width)
    print("Total of annotations : {}".format(i))
           

之作為學習使用不商用

ref:https://blog.csdn.net/weixin_42447868/article/details/105675158

繼續閱讀