官方文檔連結:https://tensorflow.google.cn/versions/r1.8/api_docs/python/tf/train/batch
tf.train.batch(
tensors,
batch_size,
num_threads=1,
capacity=32,
enqueue_many=False,
shapes=None,
dynamic_pad=False,
allow_smaller_final_batch=False,
shared_name=None,
name=None
)
函數功能:利用一個tensor的清單或字典來擷取一個batch資料
參數介紹:
- tensors:一個清單或字典的tensor用來進行入隊
- batch_size:設定每次從隊列中擷取出隊資料的數量
- num_threads:用來控制入隊tensors線程的數量,如果num_threads大于1,則batch操作将是非确定性的,輸出的batch可能會亂序
- capacity:一個整數,用來設定隊列中元素的最大數量
- enqueue_many:在tensors中的tensor是否是單個樣本
- shapes:可選,每個樣本的shape,預設是tensors的shape
- dynamic_pad:Boolean值.允許輸入變量的shape,出隊後會自動填補次元,來保持與batch内的shapes相同
- allow_samller_final_batch:可選,Boolean值,如果為True隊列中的樣本數量小于batch_size時,出隊的數量會以最終遺留下來的樣本進行出隊,如果為Flalse,小于batch_size的樣本不會做出隊處理
- shared_name:可選,通過設定該參數,可以對多個會話共享隊列
- name:可選,操作的名字
從數組中每次擷取一個batch_size的資料
import numpy as np
import tensorflow as tf
def next_batch():
datasets = np.asarray(range(0,20))
input_queue = tf.train.slice_input_producer([datasets],shuffle=False,num_epochs=1)
data_batchs = tf.train.batch(input_queue,batch_size=5,num_threads=1,
capacity=20,allow_smaller_final_batch=False)
return data_batchs
if __name__ == "__main__":
data_batchs = next_batch()
sess = tf.Session()
sess.run(tf.initialize_local_variables())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess,coord)
try:
while not coord.should_stop():
data = sess.run([data_batchs])
print(data)
except tf.errors.OutOfRangeError:
print("complete")
finally:
coord.request_stop()
coord.join(threads)
sess.close()
![](https://img.laitimes.com/img/_0nNw4CM6IyYiwiM6ICdiwiIwczLcVmds92czlGZvwVP9EUTDZ0aRJkSwk0LcxGbpZ2LcBDM08CXlpXazRnbvZ2LcRlMMVDT2EWNvwFdu9mZvwVMVRkTzUFVPVTS6hFMG1mYw50MMBjVtJWd0ckW65UbM5WOHJWa5kHT20ESjBjUIF2LcRHelR3LcJzLctmch1mclRXY39jM3gDOxQTNxIDOxATM4EDMy8CX0Vmbu4GZzNmLn9Gbi1yZtl2Lc9CX6MHc0RHaiojIsJye.jpg)
注意:tf.train.batch這個函數的實作是使用queue,queue的QueueRunner被添加到目前計算圖的"QUEUE_RUNNER"集合中,所在使用初始化器的時候,需要使用tf.initialize_local_variables(),如果使用tf.global_varialbes_initialize()時,會報: Attempting to use uninitialized value