天天看点

voc数据集_如何将VOC数据集转换成Tfrecord文件

voc数据集_如何将VOC数据集转换成Tfrecord文件

代码参考了https://github.com/zzh8829/yolov3-tf2/blob/master/tools/voc2012.py

对zzh8829写的yolov3进行了一些理解,下面是对其中voc2012.py的一些理解。

main()函数

为什么要使用TFRecord?

TFRecord 是 TensorFlow 中的数据集存储格式。当我们将数据集整理成 TFRecord 格式后,TensorFlow 就可以高效地读取和处理这些数据集,从而帮助我们更高效地进行大规模的模型训练。

如何制作TFRecord文件呢?

TFRecord 可以理解为一系列序列化的 tf.train.Example 元素所组成的列表文件,而每一个 tf.train.Example 又由若干个 tf.train.Feature 的字典组成。形式如下:

class_map = {'aeroplane':0, 'bicycle':1……}

创建一个TFRecords文件,进行一系列操作,完成后,关闭TFRecord文件。

代码示例:

writer=tf.io.TFRecordWriter(FLAGS.output_file)

……

writer.close()

在这中间依次进行:

1.读入图片的名称,放入Image_list!注意!是名称,举个栗子,如果图片的名称是000001.jpg,那么读入的名称就是000001,是不包含后缀格式。

2.遍历image_list

A.把名称与地址拼接,生成annotation_xml,是对应的图片的xml文件的地址

B.把一串xml解析为一个xml元素

C.调用自定义函数parse_xml函数解析xml文件,返回annotation

D.调用自定义函数build_example(annotation, class_map)

E.将该 tf.train.Example 对象序列化为字符串,并通过一个预先定义的 tf.io.TFRecordWriter 写入 TFRecord 文件

writer.write(tf_example.SerializeToString())

parse_xml()函数

1.这里有个递归调用!!!child_result也是一个字典

2.遍历xml文件中的根节点,然后在根节点中调用parse_xml()函数,如果父节点没有子节点了,那么就返回当前节点的元素名称xml.tag和内容xml.text

3.如果当前子节点不是'object',那么以字典的形式写入result={}中去。

如果当前子节点是'object',且读入第一个边界框,创建一个列表,以后读入的边界框也放进这个列表。

4.annotation={'folder':JPEGImages, 'filename':'000001.jpg', ……, 'size':{'width':'1280', 'height':'1024', 'depth':'3'}, 'object':[一个列表里面放了字典'name':'bicycle', 'bndbox':{'xmin':'100', 'ymin':'100', 'xmax':'100', 'ymax':'100'}]}

build_example()函数

1.img_path存放图片地址

2.img_raw读入图片数据,类型为bytes

3.这里我还不知道为什么要用key = hashlib.sha256(img_raw).hexdigest()!!!!!!!!!

4.创建好多个空列表,主要是存放'object'中的信息,如边界框的四个位置信息、类别信息、是否为难识别物体、目标名称的utf-8编码字符串、目标名称对应的索引号等

!!!注意!!!!!xmin,xmax已经归一化,调整为0-1之间,方便后续处理

5.TFRecord 可以理解为一系列序列化的 tf.train.Example 元素所组成的列表文件,而每一个 tf.train.Example 又由若干个 tf.train.Feature 的字典组成,这个字典封装成tf.train.Features。

example = tf.train.Example(features=tf.train.Features(feature={

'image/height'

:tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),

……

'image/filename'

:tf.train.Feature(bytes_list=tf.train.BytesList(value=[

annotation[

'filename'

].encode(

'utf8'

)])),

……

'image/object/bbox/xmin'

:tf.train.Feature(float_list=tf.train.FloatList(value=xmin)),

}))