Tensorflow:无法使用Dataset API创建具有相应标签的数据集

时间:2018-08-13 14:29:14

标签: python tensorflow dataset

我的数据集包含featureslabels,例如

features, labels = (np.random.sample((5,2)), np.random.sample((5,1))) 

这意味着该数据集中有5个数据元素(有5行,每行是2维特征和1维标签)。

我使用tf.data.Dataset用以下代码创建数据集:

import tensorflow as tf
import numpy as np
features, labels = (np.random.sample((5,2)), np.random.sample((5,1))) 
print("feature : \n", features)
print("labels : \n", labels)

dataset = tf.data.Dataset.from_tensor_slices((features, labels))
iter = dataset.make_one_shot_iterator()            
x, y = iter.get_next()                                       
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())   
    print("element:\n", sess.run(x), sess.run(y))

我使用Windows 10 TF1.5,然后得到结果:

feature :
 [[0.10261779 0.28041519]  # feature0
 [0.91091857 0.95644642]   # feature1
 [0.77542043 0.49631646]   # ...
 [0.33241678 0.28630983]
 [0.39095336 0.76686785]]
labels :
 [[0.54097027]             # label0
 [0.99022349]              # label1
 [0.87510303]              # ...
 [0.07331254]
 [0.10868335]]
element:
 [0.10261779 0.28041519] [0.99022349]

创建数据集时,我希望feature0 [0.10261779 0.28041519]与label0 [0.54097027]相对应。但是使用代码,feature0 [0.10261779 0.28041519]对应于label1 [0.99022349]。顺序错误。我不知道get_next的实际工作原理。

我想知道是否有任何方法可以通过使用tensorflow Dataset API来按顺序输出功能和标签。

谢谢

1 个答案:

答案 0 :(得分:2)

问题在于,分别运行x 运行y会使迭代器前进两次。也就是说:调用sess.run(x)时,将返回features的第一个元素,并且迭代器是高级的。然后调用sess.run(y)将返回labels second 元素,因为xy都基于相同的迭代器。如果要再次调用sess.run(x),它将返回features third 元素,依此类推。

我建议您像这样重写代码,例如:

...
next_batch_op = iter.get_next()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feature_batch, label_batch = sess.run(next_batch_op)
    print("element:\n", feature_batch, label_batch)

这只会运行迭代器一次,并使您能够访问相应的功能/标签。

作为替代,我只是尝试了以下方法,但似乎可行:

...
x, y = iter.get_next()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("element:\n", sess.run([x, y]))

与您的代码不同的是,我们在单个x调用中同时运行yrun。但是,我发现第一种解决方案更加清晰。