天天看點

TensorFlow學習筆記-組合訓練資料

  Tensorflow資料預處理操作:http://blog.csdn.net/lovelyaiq/article/details/78716325

  Tensorflow讀出TFRecord中的資料,然後在經過預處理操作,此時需要注意:資料還是單個,而網絡的輸入一般以Batch為機關,是以我們需要将單個的資料組合成一個Batch,做為神經網絡的輸入。

  Tensorflow提供組合訓練資料的函數有四個:tf.train.batch(),tf.train.shuffle_batch()與tf.train.batch_join、tf.train.shuffle_batch_join,這裡為什麼要用與呢?其實他們是針對兩種情況。tf.train.batch和tf.train.batch_join的差別,一般來說,單一檔案多線程,選用tf.train.batch(需要打亂樣本,有對應的tf.train.shuffle_batch);而對于多線程多檔案的情況,一般選用tf.train.batch_join來擷取樣本(打亂樣本同樣也有對應的tf.train.shuffle_batch_join使用)。下面會通過具體的例子來說明。tf.train.batch(),tf.train.shuffle_batch()這兩個函數都會生成一個隊列,隊列的入隊操作是生成單個樣例的方法,也就是經過預處理之後的圖像。

我們首先看看一下這兩個函數的定義:

def batch(tensors, batch_size, num_threads=, capacity=,
          enqueue_many=False, shapes=None, dynamic_pad=False,
          allow_smaller_final_batch=False, shared_name=None, name=None):
def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
                  num_threads=, seed=None, enqueue_many=False, shapes=None,
                  allow_smaller_final_batch=False, shared_name=None, name=None):
           

  這兩個函數的主要參數為:

  1、tensors入隊隊列,預處理後的資料和對應的标簽。

  2、batch_size:batch的大小。如果太大,則需要占用較多的記憶體資源,如果太小,那麼出隊操作可能會因為沒有資料而被阻塞,進而導緻訓練效率降低。

  3、capacity:隊列的最大容量,當隊列的長度等于容量時,Tensorflow将暫停入隊操作,而隻是等待元素出隊。當隊列個數小于容量時,Tensorflow将自動啟動入隊操作。

  4、num_threads:啟動多少個線程讀取檔案和預處理。

  5、allow_smaller_final_batch:如果設定True,則會允許最後一個Batch的大小比較小,當沒有足夠的資料輸入時。

  6、min_after_dequeue:限制出隊時隊列中元素的最小個數,如果隊列中剩餘個數太小,則随機打亂的作用就會不大。

例如:API中關于tf.train.shuffle_batch()的一個例子為:

# Creates batches of 32 images and 32 labels.
  image_batch, label_batch = tf.train.shuffle_batch(
        [single_image, single_label],
        batch_size=,
        num_threads=,
        capacity=,
        min_after_dequeue=)
           

  訓練資料被組合成Batch後,就可以進行訓練了。而詳細的設計流程參考:http://blog.csdn.net/lovelyaiq/article/details/78709826。

# -*- coding: utf-8 -*-
import tensorflow as tf
import os

# 生成整數型的屬性
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

num_shards = 
instance_per_shard = 
#
for i in range(num_shards):
    filename = 'model/data.tfrecord-%.5d-of%.5d' %(i,num_shards)
    writer = tf.python_io.TFRecordWriter(filename)
    for j in range(instance_per_shard):
        example = tf.train.Example(features=tf.train.Features(feature={
            'i':_int64_feature(i),
            'j':_int64_feature(j)
        }))
        writer.write(example.SerializeToString())
    writer.close()


tf_record_pattern = os.path.join( 'model/', 'data.tfrecord-*' )
data_files = tf.gfile.Glob( tf_record_pattern )
filename_quene = tf.train.string_input_producer(data_files,shuffle=False)

reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_quene)

features = tf.parse_single_example(serialized_example,features={
    'i': tf.FixedLenFeature([],tf.int64),
    'j': tf.FixedLenFeature( [], tf.int64),
})

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    # print(sess.run(filename))
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess = sess, coord = coord)
    for i in range():
        print(sess.run([features['i'],features['j']]))

    coord.request_stop()
    coord.join(threads)
# 輸出結果為:
[, ]
[, ]
[, ]
[, ]
[, ]
[, ]
[, ]
[, ]
[, ]
[, ]
[, ]
[, ]
[, ]
[, ]
[, ]
           

  從結果中可以看出它是通過順序讀取檔案中的内容。

  接下來,我們使用tf.train.batch來組合資料

example,lable = features['i'],features['j']
batch_size = 
capacity =  +  * batch_size

example_batch, label_batch = tf.train.batch([example,lable],batch_size = batch_size, capacity=capacity)

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    # print(sess.run(filename))
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess = sess, coord = coord)
    for i in range():
        print(sess.run([example_batch,label_batch]))

    coord.request_stop()
    coord.join(threads)
# 輸出結果為:
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
           

  從結果中可以看出它是通過順序讀取檔案中的内容。

  而當我使用tf.train.shuffle_batch時,輸出結果的順序已經被打亂。

example_batch, label_batch = tf.train.shuffle_batch([example,lable],batch_size = batch_size, capacity=capacity,min_after_dequeue =)
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
           

  當使用多個線程讀取多個檔案時,這時候就需要使用 tf.train.batch_join或tf.train.shuffle_batch_join。

nums_read = []
examples_queue = tf.RandomShuffleQueue(
    capacity= +  * batch_size,
    min_after_dequeue=,
    dtypes=tf.string )

# for i in range():
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_quene)
# nums_read.append(examples_queue.enqueue(serialized_example))

tf.train.queue_runner.add_queue_runner(
    tf.train.queue_runner.QueueRunner( examples_queue, [examples_queue.enqueue(serialized_example)]* ) )
example_serialized = examples_queue.dequeue()


images_and_labels = []
for thread_id in range(  ):
    # Parse a serialized Example proto to extract the image and metadata.
    features = tf.parse_single_example( serialized_example, features={
        'i': tf.FixedLenFeature( [], tf.int64 ),
        'j': tf.FixedLenFeature( [], tf.int64 ),
    } )
    example, lable = features['i'], features['j']

    images_and_labels.append( [example, lable] )

example_batch, label_batch = tf.train.batch_join(
    images_and_labels,
    batch_size=batch_size,
    capacity=capacity )
# 結果為:
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
[array([, , ]), array([, , ])]
           

繼續閱讀