代码参考了https://github.com/zzh8829/yolov3-tf2/blob/master/tools/voc2012.py
对zzh8829写的yolov3进行了一些理解,下面是对其中voc2012.py的一些理解。
main()函数
为什么要使用TFRecord?
TFRecord 是 TensorFlow 中的数据集存储格式。当我们将数据集整理成 TFRecord 格式后,TensorFlow 就可以高效地读取和处理这些数据集,从而帮助我们更高效地进行大规模的模型训练。
如何制作TFRecord文件呢?
TFRecord 可以理解为一系列序列化的 tf.train.Example 元素所组成的列表文件,而每一个 tf.train.Example 又由若干个 tf.train.Feature 的字典组成。形式如下:
class_map = {'aeroplane':0, 'bicycle':1……}
创建一个TFRecords文件,进行一系列操作,完成后,关闭TFRecord文件。
代码示例:
writer=tf.io.TFRecordWriter(FLAGS.output_file)
……
writer.close()
在这中间依次进行:
1.读入图片的名称,放入Image_list!注意!是名称,举个栗子,如果图片的名称是000001.jpg,那么读入的名称就是000001,是不包含后缀格式。
2.遍历image_list
A.把名称与地址拼接,生成annotation_xml,是对应的图片的xml文件的地址
B.把一串xml解析为一个xml元素
C.调用自定义函数parse_xml函数解析xml文件,返回annotation
D.调用自定义函数build_example(annotation, class_map)
E.将该 tf.train.Example 对象序列化为字符串,并通过一个预先定义的 tf.io.TFRecordWriter 写入 TFRecord 文件
writer.write(tf_example.SerializeToString())
parse_xml()函数
1.这里有个递归调用!!!child_result也是一个字典
2.遍历xml文件中的根节点,然后在根节点中调用parse_xml()函数,如果父节点没有子节点了,那么就返回当前节点的元素名称xml.tag和内容xml.text
3.如果当前子节点不是'object',那么以字典的形式写入result={}中去。
如果当前子节点是'object',且读入第一个边界框,创建一个列表,以后读入的边界框也放进这个列表。
4.annotation={'folder':JPEGImages, 'filename':'000001.jpg', ……, 'size':{'width':'1280', 'height':'1024', 'depth':'3'}, 'object':[一个列表里面放了字典'name':'bicycle', 'bndbox':{'xmin':'100', 'ymin':'100', 'xmax':'100', 'ymax':'100'}]}
build_example()函数
1.img_path存放图片地址
2.img_raw读入图片数据,类型为bytes
3.这里我还不知道为什么要用key = hashlib.sha256(img_raw).hexdigest()!!!!!!!!!
4.创建好多个空列表,主要是存放'object'中的信息,如边界框的四个位置信息、类别信息、是否为难识别物体、目标名称的utf-8编码字符串、目标名称对应的索引号等
!!!注意!!!!!xmin,xmax已经归一化,调整为0-1之间,方便后续处理
5.TFRecord 可以理解为一系列序列化的 tf.train.Example 元素所组成的列表文件,而每一个 tf.train.Example 又由若干个 tf.train.Feature 的字典组成,这个字典封装成tf.train.Features。
example = tf.train.Example(features=tf.train.Features(feature={
'image/height':tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
……
'image/filename':tf.train.Feature(bytes_list=tf.train.BytesList(value=[
annotation[
'filename'].encode(
'utf8')])),
……
'image/object/bbox/xmin':tf.train.Feature(float_list=tf.train.FloatList(value=xmin)),
}))