如何在不使用不推荐使用的功能的情况下遍历tf.dataset?

时间:2019-07-03 01:06:30

标签: python tensorflow

我正在使用tensorflow 1.14,并且对数据集有疑问。

我的代码:

my_data = [
    [0, 1],
    [2, 3],
    [4, 5],
    [6, 7]
]

slices = tf.data.Dataset.from_tensor_slices(my_data) # get dataset
it = slices.make_one_shot_iterator() # get iterator from dataset (deprecated)
next_item = it.get_next()

它表明make_one_shot_iterator已过时。.

所以我尝试了以下代码

my_data = [
    [0, 1],
    [2, 3],
    [4, 5],
    [6, 7]
]

slices = tf.data.Dataset.from_tensor_slices(my_data) # get dataset
for q in slices:
    print(sess.run(q))

我立即收到NotFoundError异常。

我的问题:遍历数据集的正确方法是什么?

3 个答案:

答案 0 :(得分:0)

import tensorflow as tf
my_data = [
    [0, 1],
    [2, 3],
    [4, 5],
    [6, 7]
]

slices = tf.data.Dataset.from_tensor_slices(my_data) # get dataset
q = slices.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    for i in range(len(my_data)):
        print('-----')
        print(sess.run(q))

上面的代码产生 enter image description here

答案 1 :(得分:0)

尝试一下:

import tensorflow as tf
my_data = [
    [0, 1],
    [2, 3],
    [4, 5],
    [6, 7]
]
n = len(my_data)
slices = tf.data.Dataset.from_tensor_slices(my_data) # get dataset
iterator = slices.make_initializable_iterator()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    while n>0:
        print(sess.run(iterator.get_next()))
        n-=1

如果上面仍然显示deprecation消息,请尝试以下代码:

import tensorflow as tf
tf.enable_eager_execution()

my_data = [
    [0, 1],
    [2, 3],
    [4, 5],
    [6, 7]
]
slices = tf.data.Dataset.from_tensor_slices(my_data) # get dataset
for i in slices:
    print(i.numpy())

输出:

[0 1]
[2 3]
[4 5]
[6 7]

答案 2 :(得分:0)

根据tf.data.Dataset的文档,您可以执行以下简单循环:

for element in my_dataset: 
   print(element)

如您在图像中看到的,这将返回tf.Tensor。如果您想要一个简单的元组,则可以使用:

for element in my_dataset.as_numpy_iterator(): 
   print(element)

如果数据集的每个条目都具有多个元素,则可以像通常一样使用[]索引元组的内容。

enter image description here