# -*- coding: utf-8 -*-
# @Time : 2018/12/1 11:06
# @Author : MaochengHu
# @Email : [email protected]
# @File : read_tfrecord.py
# @Software: PyCharm
import os

import tensorflow as tf

# TF-Slim alias; it lives under tf.contrib in this TensorFlow version.
slim = tf.contrib.slim

# Command-line options for this script. FLAG is the flag-values object
# holding the three options defined below.
flags = tf.app.flags
flags.DEFINE_string(
    'tfrecord_path',
    '/data1/humaoc_file/classify/data/train_tfrecord/train.record',
    'path to tfrecord file',
)
flags.DEFINE_integer('resize_height', 800, 'resize height of image')
flags.DEFINE_integer('resize_width', 800, 'resize width of image')
FLAG = flags.FLAGS

def print_data(image, resized_image, label, height, width, num_prints=20):
    """Fetch and print a few examples from the input pipeline.

    Opens a fresh session, initializes global and local variables, starts
    the graph's queue runners, then evaluates all five tensors in a single
    ``sess.run`` per iteration so that every printed value describes the
    same example.

    Args:
        image: Decoded image tensor.
        resized_image: Resized image tensor.
        label: Label tensor.
        height: Height tensor read from the record.
        width: Width tensor read from the record.
        num_prints: How many examples to fetch and print (default 20).
    """
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Also initialize local variables, in case the input pipeline
        # creates any; the __main__ block does the same.
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for i in range(num_prints):
                print("______________________image({})___________________".format(i))
                print_image, print_resized_image, print_label, print_height, print_width = sess.run(
                    [image, resized_image, label, height, width])
                print("resized_image shape is: ", print_resized_image.shape)
                print("image shape is: ", print_image.shape)
                print("image label is: ", print_label)
                print("image height is: ", print_height)
                print("image width is: ", print_width)
        finally:
            # Stop and join the reader threads even if sess.run raised, so
            # they never outlive the session.
            coord.request_stop()
            coord.join(threads)

def reshape_same_size(image, output_height, output_width):
    """Resize an image to a fixed height and width (nearest neighbour).

    Args:
        image: A 3-D image `Tensor` of shape [height, width, channels].
        output_height: The height of the image after preprocessing.
        output_width: The width of the image after preprocessing.

    Returns:
        resized_image: A 3-D tensor of shape
            [output_height, output_width, channels].
    """
    output_height = tf.convert_to_tensor(output_height, dtype=tf.int32)
    output_width = tf.convert_to_tensor(output_width, dtype=tf.int32)

    # resize_nearest_neighbor operates on a 4-D batch, so add a batch axis...
    batched = tf.expand_dims(image, 0)
    resized_image = tf.image.resize_nearest_neighbor(
        batched, [output_height, output_width], align_corners=False)
    # ...and remove ONLY that axis. A bare tf.squeeze would also drop any
    # other size-1 dimension (a single channel, or an output side of 1) and
    # return a tensor of lower rank than documented. The axis list is passed
    # positionally so it works across TF 1.x argument names.
    return tf.squeeze(resized_image, [0])

def read_tfrecord(tfrecord_path, num_samples=14635, num_classes=7, resize_height=800, resize_width=800):
    """Build a shuffled TF-Slim input pipeline over a TFRecord file.

    Each record is decoded into a 3-channel image, an int64 class label, and
    the int64 height/width stored in the record. The decoded image is also
    resized bilinearly to ``resize_height`` x ``resize_width``.

    Args:
        tfrecord_path: TFRecord file to read (passed to the slim Dataset as
            ``data_sources``).
        num_samples: Example count reported by the slim Dataset; presumably
            the size of the training set -- confirm.
        num_classes: Number of label classes reported by the slim Dataset.
        resize_height: Target height of ``resized_image``.
        resize_width: Target width of ``resized_image``.

    Returns:
        A tuple ``(resized_image, label, image, height, width)`` of tensors.
        ``resized_image`` has shape [resize_height, resize_width, 3] and is
        the float output of ``tf.image.resize_bilinear``. ``image`` is the
        decoded image at its stored size.
    """
    keys_to_features = {
        'image/encoded': tf.FixedLenFeature([], tf.string, default_value=''),
        'image/format': tf.FixedLenFeature([], tf.string, default_value='jpeg'),
        'image/class/label': tf.FixedLenFeature([], tf.int64, default_value=0),
        'image/height': tf.FixedLenFeature([], tf.int64, default_value=0),
        'image/width': tf.FixedLenFeature([], tf.int64, default_value=0),
    }

    items_to_handlers = {
        'image': slim.tfexample_decoder.Image(
            image_key='image/encoded', format_key='image/format', channels=3),
        'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[]),
        'height': slim.tfexample_decoder.Tensor('image/height', shape=[]),
        'width': slim.tfexample_decoder.Tensor('image/width', shape=[]),
    }
    decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)

    # Human-readable item descriptions; the label range follows num_classes
    # instead of a hard-coded 0-9.
    items_to_descriptions = {
        'image': 'An image with shape image_shape.',
        'label': 'A single integer between 0 and {}.'.format(num_classes - 1),
    }

    dataset = slim.dataset.Dataset(
        data_sources=tfrecord_path,
        reader=tf.TFRecordReader,
        decoder=decoder,
        num_samples=num_samples,
        items_to_descriptions=items_to_descriptions,
        num_classes=num_classes,
    )
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset=dataset,
        num_readers=3,
        shuffle=True,
        common_queue_capacity=256,
        common_queue_min=128,
        seed=None)
    image, label, height, width = provider.get(['image', 'label', 'height', 'width'])

    # resize_bilinear expects a 4-D batch: add a batch axis, then remove only
    # that axis (a bare tf.squeeze would also drop a size-1 output side).
    resized_image = tf.image.resize_bilinear(
        tf.expand_dims(image, 0), size=[resize_height, resize_width])
    resized_image = tf.squeeze(resized_image, [0])
    return resized_image, label, image, height, width

if __name__ == '__main__':
    resized_image, label, image, height, width = read_tfrecord(tfrecord_path='train.record',
                                                               resize_height=800,
                                                               resize_width=800)
    print_data(image, resized_image, label, height, width)

    init_g = tf.global_variables_initializer()
    init_l = tf.local_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_g)
        sess.run(init_l)
        # Track the reader threads with a Coordinator so they can be stopped
        # and joined before the session closes.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            print("SDDFA")
            # Fetch image and label in ONE run: two separate .eval() calls
            # would each dequeue a different example, so the label would not
            # belong to the image.
            trX, trY = sess.run([image, label])
            print("AA")
            print(trX.shape)
        finally:
            coord.request_stop()
            coord.join(threads)