TensorFlow data preprocessing: http://blog.csdn.net/lovelyaiq/article/details/78716325
After TensorFlow reads records out of a TFRecord file and runs them through preprocessing, each record is still a single example, while a network normally consumes input one batch at a time. We therefore need to combine the individual examples into batches before feeding them to the neural network.
TensorFlow provides four functions for combining training data: tf.train.batch(), tf.train.shuffle_batch(), tf.train.batch_join(), and tf.train.shuffle_batch_join(). Why two pairs? They cover two different situations. As for the difference between tf.train.batch and tf.train.batch_join: in general, use tf.train.batch for multiple threads reading a single file (if the examples need shuffling, there is the corresponding tf.train.shuffle_batch); for multiple threads reading multiple files, use tf.train.batch_join (with tf.train.shuffle_batch_join as its shuffling counterpart). Concrete examples follow below. Both tf.train.batch() and tf.train.shuffle_batch() create a queue whose enqueue operation produces a single example, i.e. one preprocessed image.
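The key API difference is the shape of the first argument: tf.train.batch takes a single tensor list fed by one reading pipeline (possibly with several threads pulling from it), while tf.train.batch_join takes a list of tensor lists and starts one enqueue thread per inner list. A minimal sketch of the two call shapes (the tensors img and lbl, the helper read_one_example, and the numeric values are illustrative placeholders, not from the original post):

# tf.train.batch: ONE pipeline; num_threads threads all run the same ops.
img_batch, lbl_batch = tf.train.batch(
    [img, lbl], batch_size=32, num_threads=4, capacity=1000)

# tf.train.batch_join: one (img, lbl) pipeline per reader, one thread each.
example_list = [read_one_example(filename_queue) for _ in range(4)]
img_batch, lbl_batch = tf.train.batch_join(
    example_list, batch_size=32, capacity=1000)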
Let's first look at the definitions of these two functions:
def batch(tensors, batch_size, num_threads=1, capacity=32,
          enqueue_many=False, shapes=None, dynamic_pad=False,
          allow_smaller_final_batch=False, shared_name=None, name=None):

def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
                  num_threads=1, seed=None, enqueue_many=False, shapes=None,
                  allow_smaller_final_batch=False, shared_name=None, name=None):
The main parameters of these two functions are:
1. tensors: the list of tensors to enqueue, i.e. the preprocessed data and the corresponding labels.
2. batch_size: the size of each batch. If it is too large, each batch consumes more memory; if it is too small, the dequeue operation may block while waiting for data, which lowers training efficiency.
3. capacity: the maximum capacity of the queue. When the queue is full, TensorFlow pauses the enqueue operation and only waits for elements to be dequeued; once the number of queued elements drops below the capacity, TensorFlow automatically resumes enqueueing.
4. num_threads: the number of threads used to read files and run preprocessing.
5. allow_smaller_final_batch: if set to True, the last batch is allowed to be smaller than batch_size when not enough data remains (a runnable sketch follows this list).
6. min_after_dequeue: the minimum number of elements left in the queue after a dequeue; if too few elements remain, random shuffling has little effect.
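To make item 5 concrete, here is a minimal, self-contained sketch (not from the original post) of a bounded pipeline whose final batch comes up short; it assumes an input_producer limited to num_epochs=1 so the input actually runs out:

import tensorflow as tf

data = tf.constant([0, 1, 2, 3, 4, 5, 6], dtype=tf.int64)  # 7 examples
q = tf.train.input_producer(data, num_epochs=1, shuffle=False)
batch = tf.train.batch([q.dequeue()], batch_size=3,
                       allow_smaller_final_batch=True)

with tf.Session() as sess:
    # num_epochs is tracked with a local variable, so initialize those too
    sess.run([tf.global_variables_initializer(),
              tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            # prints [array([0, 1, 2])], [array([3, 4, 5])], then [array([6])]
            print(sess.run(batch))
    except tf.errors.OutOfRangeError:
        pass
    finally:
        coord.request_stop()
    coord.join(threads)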
For example, the API documentation gives the following example for tf.train.shuffle_batch():
# Creates batches of 32 images and 32 labels.
image_batch, label_batch = tf.train.shuffle_batch(
      [single_image, single_label],
      batch_size=32,
      num_threads=4,
      capacity=50000,
      min_after_dequeue=10000)
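A practical sizing rule from the same documentation: capacity must be larger than min_after_dequeue, and the extra headroom determines how much is prefetched. A recommended choice is

capacity = min_after_dequeue + (num_threads + small_safety_margin) * batch_size

for example min_after_dequeue + 3 * batch_size, which is the pattern the code below follows.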
Once the training data has been combined into batches, training can proceed; for the detailed design flow, see: http://blog.csdn.net/lovelyaiq/article/details/78709826. The complete example below first writes a few small TFRecord shards and then reads them back.
# -*- coding: utf-8 -*-
import tensorflow as tf
import os

# Generate an int64 feature
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

num_shards = 2          # number of TFRecord shard files (assumed; the original elides the value)
instance_per_shard = 2  # number of examples per shard (assumed; the original elides the value)

# Write num_shards files; each example stores its shard index i and in-shard index j.
for i in range(num_shards):
    filename = 'model/data.tfrecord-%.5d-of%.5d' % (i, num_shards)
    writer = tf.python_io.TFRecordWriter(filename)
    for j in range(instance_per_shard):
        example = tf.train.Example(features=tf.train.Features(feature={
            'i': _int64_feature(i),
            'j': _int64_feature(j)
        }))
        writer.write(example.SerializeToString())
    writer.close()
# Glob the shard files and build a filename queue; shuffle=False keeps reads sequential.
tf_record_pattern = os.path.join('model/', 'data.tfrecord-*')
data_files = tf.gfile.Glob(tf_record_pattern)
filename_queue = tf.train.string_input_producer(data_files, shuffle=False)

reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example, features={
    'i': tf.FixedLenFeature([], tf.int64),
    'j': tf.FixedLenFeature([], tf.int64),
})
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(15):  # iteration count assumed from the output below
        print(sess.run([features['i'], features['j']]))
    coord.request_stop()
    coord.join(threads)
# Output (values reconstructed for the assumed num_shards = 2 and
# instance_per_shard = 2; the numbers in the original post were lost).
# The four examples come back in file order, repeating:
[0, 0]
[0, 1]
[1, 0]
[1, 1]
[0, 0]
[0, 1]
...
From the output we can see that the file contents are read back in order.
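As a quick sanity check (not part of the original post), you can count the records in each shard without building a graph, using tf.python_io.tf_record_iterator:

for f in tf.gfile.Glob('model/data.tfrecord-*'):
    n = sum(1 for _ in tf.python_io.tf_record_iterator(f))
    print(f, 'contains', n, 'records')  # expect instance_per_shard per file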
Next, we use tf.train.batch to combine the data into batches:
example, label = features['i'], features['j']
batch_size = 3                      # assumed; the original elides the value
capacity = 1000 + 3 * batch_size    # 3 batches of headroom over a 1000-element reserve (assumed)
example_batch, label_batch = tf.train.batch([example, label], batch_size=batch_size, capacity=capacity)
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(15):
        print(sess.run([example_batch, label_batch]))
    coord.request_stop()
    coord.join(threads)
# Output (reconstructed under the same assumptions; with the default single
# thread the batching is deterministic, so these four batches repeat):
[array([0, 0, 1]), array([0, 1, 0])]
[array([1, 0, 0]), array([1, 0, 1])]
[array([1, 1, 0]), array([0, 1, 0])]
[array([0, 1, 1]), array([1, 0, 1])]
...
Again the examples come out in their original order, only grouped into batches. When tf.train.shuffle_batch is used instead, the output order is shuffled:
example_batch, label_batch = tf.train.shuffle_batch([example, label], batch_size=batch_size, capacity=capacity, min_after_dequeue=30)  # min_after_dequeue assumed
# Output: the same (i, j) pairs, but each batch is drawn in random order
# (the exact values printed in the original post were lost).
When multiple threads read from multiple files, tf.train.batch_join or tf.train.shuffle_batch_join is needed. The snippet below first funnels serialized records through a shared RandomShuffleQueue and then builds one parsing pipeline per preprocessing thread:
num_preprocess_threads = 4   # assumed; the original elides the thread count

# A shared queue of raw serialized records; dequeuing yields them in random order.
examples_queue = tf.RandomShuffleQueue(
    capacity=1000 + 3 * batch_size,   # assumed value
    min_after_dequeue=1000,           # assumed value
    dtypes=tf.string)

reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)

# Register an enqueue thread that keeps feeding serialized records into the
# shared queue. With several files you would create several readers here and
# pass all of their enqueue ops in this list.
tf.train.queue_runner.add_queue_runner(
    tf.train.queue_runner.QueueRunner(
        examples_queue, [examples_queue.enqueue(serialized_example)]))
example_serialized = examples_queue.dequeue()

images_and_labels = []
for thread_id in range(num_preprocess_threads):
    # Parse a serialized Example proto to extract the image and metadata.
    # Note: parse example_serialized (the dequeue result), not
    # serialized_example, so every pipeline draws from the shared queue.
    features = tf.parse_single_example(example_serialized, features={
        'i': tf.FixedLenFeature([], tf.int64),
        'j': tf.FixedLenFeature([], tf.int64),
    })
    example, label = features['i'], features['j']
    images_and_labels.append([example, label])

# batch_join starts one enqueue thread per element of images_and_labels.
example_batch, label_batch = tf.train.batch_join(
    images_and_labels,
    batch_size=batch_size,
    capacity=capacity)
# Output: batches of the same (i, j) pairs; because the serialized records pass
# through a RandomShuffleQueue before parsing, the order is random (the exact
# values printed in the original post were lost).
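For completeness (the post mentions it but never shows it), tf.train.shuffle_batch_join takes the same list-of-tensor-lists input and additionally shuffles in the batching queue itself; a minimal sketch with the same assumed constants:

example_batch, label_batch = tf.train.shuffle_batch_join(
    images_and_labels,
    batch_size=batch_size,
    capacity=capacity,
    min_after_dequeue=30)  # assumed value, as above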