天天看点

Win10搭建mmdetection2.6环境并训练模型(二)一、数据集准备二、训练前修改网络配置

Win10搭建mmdetection2.6环境并训练模型(二)

  • 一、数据集准备
  • 二、训练前修改网络配置
    • 1.网络层修改
    • 2.配置文件修改
    • 3.开始训练

一、数据集准备

这里以cascade_rcnn为例,首先是对图片打标签,打标签的话是用labelme软件,生成json文件,软件使用方式如图。

Win10搭建mmdetection2.6环境并训练模型(二)一、数据集准备二、训练前修改网络配置

我这里用的是labeimg软件,生成xml文件,然后xml文件需要转换成json文件,利用python脚本很容易转换,代码贴在下面。各取所需

#labelimg2labelme.py
import xml.etree.ElementTree as ET  # 读取xml。
import os
import json

def parse_rec(rootPath, file):
    pathFile = os.path.join(rootPath, file)
    root = ET.parse(pathFile)  # 解析读取xml函数
    floder = root.find('folder').text
    filename = root.find('filename').text
    path = root.find('path').text
    print(floder, filename, path)
    sz = root.find('size')
    width = int(sz[0].text)
    height = int(sz[1].text)
    print(width, height)
    data = {}
    data['imagePath'] = filename
    data['flags'] = {}
    data['imageWidth'] = width
    data['imageHeight'] = height
    data['imageData'] = None
    data['version'] = "4.5.6"
    data["shapes"] = []
    for child in root.findall('object'):  # 找到图片中的所有框
        sub = child.find('bndbox')  # 找到框的标注值并进行读取
        xmin = float(sub[0].text)
        ymin = float(sub[1].text)
        xmax = float(sub[2].text)
        ymax = float(sub[3].text)
        points = [[xmin, ymin], [xmax, ymax]]
        itemData = {'points': []}
        itemData['points'].extend(points)
        name = child.find("name").text
        itemData["flag"] = {}
        itemData["group_id"] = None
        itemData["shape_type"] = "rectangle"
        itemData["label"] = name
        data["shapes"].append(itemData)

    (filename, extension) = os.path.splitext(file)
    jsonName = ".".join([filename, "json"])
    print(rootPath, jsonName)
    jsonPath = os.path.join(rootPath, jsonName)
    with open(jsonPath, "w") as f:
        json.dump(data, f)
    print("加载入文件完成...")


if __name__ == '__main__':
    path = "这里写入图片路径"
    for root, dirs, files in os.walk(path):
        for file in files:
            if file.endswith(".xml"):
                parse_rec(root, file)
           

数据集最终放在data/coco文件夹下:

Win10搭建mmdetection2.6环境并训练模型(二)一、数据集准备二、训练前修改网络配置

二、训练前修改网络配置

1.网络层修改

由于原始的网络是针对公开数据集的,这里我们需要对自己的项目修改,代码如下:

import torch
pretrained_weights  = torch.load('checkpoints/cascade_rcnn_r50_fpn_1x_coco_20200316-3dc56deb.pth')#这里换成你自己的预训练模型

num_class = 3#这里是自己的类别数量
pretrained_weights['state_dict']['roi_head.bbox_head.0.fc_cls.weight'].resize_(num_class+1, 1024)
pretrained_weights['state_dict']['roi_head.bbox_head.0.fc_cls.bias'].resize_(num_class+1)
pretrained_weights['state_dict']['roi_head.bbox_head.1.fc_cls.weight'].resize_(num_class+1, 1024)
pretrained_weights['state_dict']['roi_head.bbox_head.1.fc_cls.bias'].resize_(num_class+1)
pretrained_weights['state_dict']['roi_head.bbox_head.2.fc_cls.weight'].resize_(num_class+1, 1024)
pretrained_weights['state_dict']['roi_head.bbox_head.2.fc_cls.bias'].resize_(num_class+1)

torch.save(pretrained_weights, "cascade_rcnn_r50_fpn_1x_%d.pth"%num_class)#这里是修改后的模型存放的地方
           

一共要注意这三个地方,我在上面标记出来了

标题文本样式列表链接目录代码片表格注脚注释自定义列表LaTeX 数学公式插入甘特图插入UML图插入Mermaid流程图插入Flowchart流程图插入类图快捷键

标题复制

2.配置文件修改

修改configs/base/schedules/schedule_1x.py 中

total_epochs总训练次数和 lr学习率

# optimizer
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[90, 95])
total_epochs = 100
           

修改configs/base/default_runtime.py

load_from=下面生成的新的权重文件

checkpoint_config = dict(interval=5)训练多少次保存一次权重

log_config = dict(

interval=11, #填写你的val2014文件夹图片数目

hooks=[

dict(type=‘TextLoggerHook’),

# dict(type=‘TensorboardLoggerHook’)

])

checkpoint_config = dict(interval=10)
# yapf:disable
log_config = dict(
    interval=73,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])
# yapf:enable
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = 'cascade_rcnn_r50_fpn_1x_3.pth'
resume_from = None
workflow = [('train', 1)]

           

修改mmdet/datasets/coco.py 将原来的80类替换成自己的类别。

@DATASETS.register_module()
class CocoDataset(CustomDataset):

    CLASSES = ('wallet','phone','bag')#这边是自己的类别
           

3.开始训练

在power shell中输入如下命令

python tools/train.py configs/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.py
           

继续阅读