天天看點

tf.train.string_input_producer()和tf.train.slice_input_producer()import tensorflow as tf path_list=['

Tensorflow一共提供了3種讀取資料的方法:

供給資料(Feeding): 在TensorFlow程式運作的每一步, 讓Python代碼來供給資料,比如說用PIL和numpy處理資料然後喂入神經網絡。

從檔案讀取資料: 在TensorFlow圖的起始, 讓一個輸入管線從檔案中讀取資料,這就是這篇文将要講的内容。

預加載資料: 在TensorFlow圖中定義常量或變量來儲存所有資料(僅适用于資料量比較小的情況)。

對于大的資料集很難用numpy數組儲存,是以這裡介紹一下Tensorflow讀取很大資料集的方法:string_input_producer()和slice_input_producer()。

他們兩者差別可以簡單了解為:string_input_producer每次放出一個檔案名。slice_input_producer可以既可以同時放出檔案名和它對應的label,也可以隻放出一個檔案名。而在實際應用代碼的時候也隻是讀取檔案的方式不一樣,其他大緻相同。

string_input_producer加載圖檔的reader是reader = tf.WholeFileReader() key,value = reader.read(path_queue)其中key是檔案名,value是byte類型的檔案流二進制,一般需要解碼(decode)一下才能變成數組,然後進行reshape操作。

slice_input_producer加載圖檔的reader使用tf.read_file(filename)直接讀取。記得圖檔需要解碼和resize成數組,才可以放入記憶體隊列file_queue中等待調用。

import tensorflow as tf
path_list=['A.png','B.png','C.png']
img_path=tf.convert_to_tensor(path_list,dtype=tf.string)#将list轉化張量tensor

image=tf.train.string_input_producer(img_path,num_epochs=1)#放入檔案名隊列中,epoch是1

def load_img(path_queue):
    #建立一個隊列讀取器,然後解碼成數組,與slice的不同之處,重要!!!!!!!!!
    reader=tf.WholeFileReader()
    key,value=reader.read(path_queue)

    img=tf.image.convert_image_dtype(tf.image.decode_png(value,channels=3),tf.float32)#将圖檔decode成3通道的數組
    img=tf.image.resize_images(img,size=(224,224))
    return img

img=load_img(image)
print(img.shape)
#可以看出string進行處理的時候隻處理了圖檔本身,對标簽并沒有處理。将圖檔放入記憶體隊列,因為abtch_size=1,是以一次放入一張供讀取。但是系統還是“停滞”狀态。
image_batch=tf.train.batch([img],batch_size=1)

with tf.Session() as sess:
    tf.local_variables_initializer().run()
    tf.global_variables_initializer().run()
    coord=tf.train.Coordinator()
    #tf.train.start_queue_runners()函數才會啟動填充隊列的線程,系統不再“停滞”,此後計算單元就可以拿到資料并進行計算
    thread=tf.train.start_queue_runners(sess=sess,coord=coord)
    try:
        while not coord.should_stop():
            imgs=sess.run(image_batch)
            print(imgs.shape)
    #當檔案隊列讀到末尾的時候,抛出異常
    except tf.errors.OutOfRangeError:
        print('done')
    finally:
        coord.request_stop()#将讀取檔案的線程關閉
    coord.join(thread)#将讀取檔案的線程加入到主線程中(雖然說已經關閉過)
           
import tensorflow as tf

path_list=['A.png','B.png','C.png']
#加入了标簽,在使用的時候可以直接對應标簽取出資料
label=[0,1,2]
#轉換成張量tensor類型
img_path=tf.convert_to_tensor(path_list,dtype=tf.string)
label=tf.convert_to_tensor(label,dtype=tf.int32)

#傳回了一個包含路徑和标簽的清單,并将檔案名和對應的标簽放入檔案名對列中,等待系統調用
image=tf.train.slice_input_producer([img_path,label],shuffle=True,num_epochs=1)#shuffle=Flase表示不打亂,當為True的時候打亂順序放入檔案名隊列
labels=image[1]

def load_image(path_queue):
    #讀取檔案,這點與string_input_producer不一樣!!!!!
    file_contents=tf.read_file(image[0])
    img=tf.image.convert_image_dtype(tf.image.decode_png(file_contents,channels=3),tf.float32)

    img=tf.image.resize_images(img,size=(228,228))
    return img

img=load_image(image)
print(img.shape)
#設定one_hot編碼,并将labels規定為3種,在前向傳播的時候預設會将結果的shape變為batch_size*3,進而達到分類的情況,這一步在使用标簽的時候很重要
labels=tf.one_hot(labels,3)
img_batch,label_batch=tf.train.batch([img,labels],batch_size=1)

with tf.Session() as sess:
    #initializer for num_epochs
    tf.local_variables_initializer().run()
    coord=tf.train.Coordinator()
    thread=tf.train.start_queue_runners(sess=sess,coord=coord)
    try:
        while not coord.should_stop():
            imgs,label=sess.run([img_batch,label_batch])
            print(imgs.shape)
            print(label)
    except tf.errors.OutOfRangeError:
        print('Done')
    finally:
        coord.request_stop()
    coord.join(thread)