天天看點

tensorflow:batch and shuffle_batch

f.train.batch與tf.train.shuffle_batch的作用都是從隊列中讀取資料.

tf.train.batch

tf.train.batch() 按順序讀取隊列中的資料

隊列中的資料始終是一個有序的隊列.隊頭一直按順序補充,隊尾一直按順序出隊.

參數:

  • tensors:排列的張量或詞典.
  • batch_size:從隊列中提取新的批量大小.
  • num_threads:線程數量.若批次是不确定 num_threads > 1.
  • capacity:隊列中元素的最大數量.
  • enqueue_many:tensors中的張量是否都是一個例子.
  • shapes:每個示例的形狀.(可選項)
  • dynamic_pad:在輸入形狀中允許可變尺寸.
  • allow_smaller_final_batch:為True時,若隊列中沒有足夠的項目,則允許最終批次更小.(可選項)
  • shared_name:如果設定,則隊列将在多個會話中以給定名稱共享.(可選項)
  • name:操作的名稱.(可選項)

若enqueue_many為False,則認為tensors代表一個示例.輸入張量形狀為[x, y, z]時,則輸出張量形狀為[batch_size, x, y, z].

若enqueue_many為True,則認為tensors代表一批示例,其中第一個次元為示例的索引,并且所有成員tensors在第一維中應具有相同大小.若輸入張量形狀為[*, x, y, z],則輸出張量的形狀為[batch_size, x, y, z].

tf.train.batch()示例

#!/usr/bin/python
# coding:utf-8
import tensorflow as tf
import numpy as np

images = np.random.random([5,2])
label = np.asarray(range(0, 5))
images = tf.cast(images, tf.float32)
label = tf.cast(label, tf.int32)
# 切片
input_queue = tf.train.slice_input_producer([images, label], shuffle=False)
# 按順序讀取隊列中的資料
image_batch, label_batch = tf.train.batch(input_queue, batch_size=10, num_threads=1, capacity=64)

with tf.Session() as sess:
    # 線程的協調器
    coord = tf.train.Coordinator()
    # 開始在圖表中收集隊列運作器
    threads = tf.train.start_queue_runners(sess, coord)
    image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
    for j in range(5):
        print(image_batch_v[j]),
        print(label_batch_v[j])
    # 請求線程結束
    coord.request_stop()
    # 等待線程終止
    coord.join(threads)
           

按順序讀取隊列中的資料,輸出:

[ 0.05013787  0.53446019] 0
[ 0.91189879  0.69153142] 1
[ 0.39966023  0.86109054] 2
[ 0.85078746  0.05766034] 3
[ 0.71261722  0.60514599] 4
           

tf.train.shuffle_batch

tf.train.shuffle_batch() 将隊列中資料打亂後再讀取出來.

函數是先将隊列中資料打亂,然後再從隊列裡讀取出來,是以隊列中剩下的資料也是亂序的.

  • tensors:排列的張量或詞典.
  • batch_size:從隊列中提取新的批量大小.
  • capacity:隊列中元素的最大數量.
  • min_after_dequeue:出隊後隊列中元素的最小數量,用于確定元素的混合級别.
  • num_threads:線程數量.
  • seed:隊列内随機亂序的種子值.
  • enqueue_many:tensors中的張量是否都是一個例子.
  • shapes:每個示例的形狀.(可選項)
  • allow_smaller_final_batch:為True時,若隊列中沒有足夠的項目,則允許最終批次更小.(可選項)
  • shared_name:如果設定,則隊列将在多個會話中以給定名稱共享.(可選項)
  • name:操作的名稱.(可選項)

其他與tf.train.batch()類似.

tf.train.shuffle_batch示例

#!/usr/bin/python
# coding:utf-8
import tensorflow as tf
import numpy as np

images = np.random.random([5,2])
label = np.asarray(range(0, 5))
images = tf.cast(images, tf.float32)
label = tf.cast(label, tf.int32)
input_queue = tf.train.slice_input_producer([images, label], shuffle=False)
# 将隊列中資料打亂後再讀取出來
image_batch, label_batch = tf.train.shuffle_batch(input_queue, batch_size=10, num_threads=1, capacity=64, min_after_dequeue=1)

with tf.Session() as sess:
    # 線程的協調器
    coord = tf.train.Coordinator()
    # 開始在圖表中收集隊列運作器
    threads = tf.train.start_queue_runners(sess, coord)
    image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
    for j in range(5):
        # print(image_batch_v.shape, label_batch_v[j])
        print(image_batch_v[j]),
        print(label_batch_v[j])
    # 請求線程結束
    coord.request_stop()
    # 等待線程終止
    coord.join(threads)
           

将隊列中資料打亂後再讀取出來,輸出:

[ 0.08383977  0.75228119] 1
[ 0.03610427  0.53876138] 0
[ 0.33962703  0.47629601] 3
[ 0.21824744  0.84182823] 4
[ 0.8376292   0.52254623] 2