我正在使用tensorflow 1.14,并且对数据集有疑问。
我的代码:
my_data = [
[0, 1],
[2, 3],
[4, 5],
[6, 7]
]
slices = tf.data.Dataset.from_tensor_slices(my_data) # get dataset
it = slices.make_one_shot_iterator() # get iterator from dataset (deprecated)
next_item = it.get_next()
它表明make_one_shot_iterator已过时。.
所以我尝试了以下代码
my_data = [
[0, 1],
[2, 3],
[4, 5],
[6, 7]
]
slices = tf.data.Dataset.from_tensor_slices(my_data) # get dataset
for q in slices:
print(sess.run(q))
我立即收到NotFoundError异常。
我的问题:遍历数据集的正确方法是什么?
答案 0 :(得分:0)
import tensorflow as tf
my_data = [
[0, 1],
[2, 3],
[4, 5],
[6, 7]
]
slices = tf.data.Dataset.from_tensor_slices(my_data) # get dataset
q = slices.make_one_shot_iterator().get_next()
with tf.Session() as sess:
for i in range(len(my_data)):
print('-----')
print(sess.run(q))
答案 1 :(得分:0)
尝试一下:
import tensorflow as tf
my_data = [
[0, 1],
[2, 3],
[4, 5],
[6, 7]
]
n = len(my_data)
slices = tf.data.Dataset.from_tensor_slices(my_data) # get dataset
iterator = slices.make_initializable_iterator()
with tf.Session() as sess:
sess.run(iterator.initializer)
while n>0:
print(sess.run(iterator.get_next()))
n-=1
如果上面仍然显示deprecation
消息,请尝试以下代码:
import tensorflow as tf
tf.enable_eager_execution()
my_data = [
[0, 1],
[2, 3],
[4, 5],
[6, 7]
]
slices = tf.data.Dataset.from_tensor_slices(my_data) # get dataset
for i in slices:
print(i.numpy())
输出:
[0 1]
[2 3]
[4 5]
[6 7]
答案 2 :(得分:0)
根据tf.data.Dataset的文档,您可以执行以下简单循环:
for element in my_dataset:
print(element)
如您在图像中看到的,这将返回tf.Tensor
。如果您想要一个简单的元组,则可以使用:
for element in my_dataset.as_numpy_iterator():
print(element)
如果数据集的每个条目都具有多个元素,则可以像通常一样使用[]
索引元组的内容。