Tensorflow数据集api问题

时间:2018-03-10 08:51:46

标签: tensorflow tensorflow-datasets

这是我的代码段

features=np.array([1,2,3,4,5,6,7],dtype=float)
labels=np.array([1,2,3,4,5,6,7],dtype=float)
training_data=(features,labels)

train_dataset=tf.data.Dataset.from_tensor_slices(training_data)
train_dataset=train_dataset.batch(1)

iter=train_dataset.make_one_shot_iterator()
batch=iter.get_next()

with tf.Session() as sess:
    x,y=batch
    a=x.eval()
    b=y.eval()   
    print(a,"---------->",b)

输出:     [1] ---------> [2]

预期输出[1] ---------> [1]

我已经花了6个小时,当我遇到这个问题时,我正在训练LSTM模型。我错过了什么?

1 个答案:

答案 0 :(得分:1)

问题在于,在将batch分解为x, y后,您没有得到两个简单的张量,而是得到两个迭代器:

In [15]: batch
Out[15]:
(<tf.Tensor 'IteratorGetNext_1:0' shape=(?,) dtype=float64>,
 <tf.Tensor 'IteratorGetNext_1:1' shape=(?,) dtype=float64>)

因此,x.eval()将迭代器增加1,y.eval()再次增加迭代器,使您看到值(1, 2)

相反,这样做只运行一次迭代器:

with tf.Session() as sess:
    a, b = sess.run(batch)
    print(a,"---------->",b)

你应该看到预期的结果。