天天看點

tf.data.Dataset讀取資料的幾種方式案例

讀取方式1:一次性將序列讀入計算圖中(資料會以常數形式直接嵌入計算圖)。

import tensorflow as tf
import numpy as np

# Two parallel columns with 5 rows each: "a" holds the ints 0..4 and "b" is a
# (5, 2) array of uniform random floats.
x = {"a": list(range(5)),
     "b": np.random.uniform(size=(5, 2))}

# from_tensor_slices slices along the first axis, so the dataset contains 5
# elements, for example (the "b" values are random and differ on every run):
#   {'a': 0, 'b': array([0.31102309, 0.28081324])}
#   {'a': 1, 'b': array([0.64559238, 0.9602511 ])}
#   {'a': 2, 'b': array([0.5191022 , 0.29045949])}
#   {'a': 3, 'b': array([0.80690428, 0.02572865])}
#   {'a': 4, 'b': array([0.33659348, 0.39553411])}
# The whole of `x` is read into the computation graph at once.
dataset = tf.data.Dataset.from_tensor_slices(x)

# A one-shot iterator needs no explicit initialization, but it can only walk
# the dataset once, from start to end.
iterator = dataset.make_one_shot_iterator()

# Build the get_next op once; each sess.run() of it yields the next element.
one_element = iterator.get_next()

# The context manager guarantees the session is closed and its resources are
# released, even if an unexpected error is raised.
with tf.Session() as sess:
    try:
        while True:
            print(sess.run(one_element))
    except tf.errors.OutOfRangeError:
        # Raised when get_next is run after every element has been consumed.
        print("end")
           

讀取方式2:使用 tf.placeholder 讀取資料。資料先一次性讀入記憶體(NumPy 陣列),在初始化 iterator 時才透過 feed_dict 餵入,之後每次執行 get_next 只從中取出一筆資料送入計算圖。由於資料不會以常數形式嵌入計算圖,在資料較多時可以減小計算圖佔用的空間。

import tensorflow as tf
import numpy as np

# Comma-separated text file with two numeric columns per row (see
# exchangeData2.txt); every value is parsed as float32.
path = r"E:\tf_project\練習\exchangeData2.txt"

# ndmin=2 keeps the result 2-D even if the file has a single row; otherwise
# loadtxt would return a 1-D array and the column slicing below would fail.
data = np.loadtxt(path, delimiter=",", dtype=np.float32, ndmin=2)
features = data[:, 0]  # first column
labels = data[:, 1]    # second column

# Placeholders keep the arrays out of the graph definition; the real values
# are supplied only when the iterator is initialized.
features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
# An initializable iterator must have iterator.initializer run (with the
# placeholders fed) before get_next can produce values.
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# The context manager guarantees the session is closed afterwards.
with tf.Session() as sess:
    sess.run(iterator.initializer, feed_dict={features_placeholder: features,
                                              labels_placeholder: labels})
    try:
        while True:
            value = sess.run(next_element)
            print(value)
    except tf.errors.OutOfRangeError:
        # Raised once every (feature, label) pair has been consumed.
        print("end")
           
exchangeData2.txt

6.5379,1000
6.5428,1010
6.5559,2000
6.5321,3000
6.5062,2000
6.5062,1210
6.5062,2060
6.5062,3000
6.4909,2000
6.5029,1000
6.4933,4000
6.4874,5000
6.4874,1000
6.4874,2000
6.4973,3000
6.5262,6000
6.5054,9000
6.5045,3000
6.4606,4000
6.4606,5000
6.4349,6000
6.4415,1100
6.4329,2700
6.4174,3500
6.3989,4100
6.3989,7700
6.4034,6200
6.4017,1200
6.3550,2800
6.3188,1900
6.3198,1100
6.3198,1200
           

讀取方式3:使用 tf.data.TextLineDataset 從文字檔中逐行讀取資料

import tensorflow as tf
import numpy as np

# Start from an empty default graph so that ops created by earlier runs
# (e.g. in the same notebook/session) do not accumulate.
tf.reset_default_graph()

MAX_LINES = 10  # print at most this many lines of the file

# Alternative, larger input file:
#file_path = r"E:\tf_project\NMT\zh-en\train.tags.zh-en.en"
file_path = r"E:\tf_project\練習\word_src.txt"

# Each dataset element is one line of the file.
dataset = tf.data.TextLineDataset(file_path)
# Split each line on the space character into a 1-D vector of tokens.
dataset = dataset.map(lambda string: tf.string_split([string]).values)
# Pair each token vector with its number of tokens.
dataset = dataset.map(lambda x: (x, tf.size(x)))
iterator = dataset.make_one_shot_iterator()

# Create the get_next op ONCE, outside the loop. Calling iterator.get_next()
# inside the loop would add a new op to the graph on every iteration, making
# the graph grow without bound.
next_element = iterator.get_next()

# The context manager guarantees the session is closed afterwards.
with tf.Session() as sess:
    try:
        for _ in range(MAX_LINES):
            print(sess.run(next_element))
    except tf.errors.OutOfRangeError:
        # Raised when the file has fewer than MAX_LINES lines.
        print("end")
           
word_src.txt

it is the time
           

繼續閱讀