我正在尝试使用TF documentation中的示例作为tf.data.Dataset.window
的示例,但是文档中的示例失败了。
源自文档的代码:
import tensorflow as tf
ds = tf.data.Dataset.range(7).window(2)
next_element = ds.make_one_shot_iterator().get_next()
with tf.Session() as sess:
print(sess.run(next_element))
产生此错误(删除跟踪):
TypeError: Can not convert a _VariantDataset into a Tensor or Operation.
During handling of the above exception, another exception occurred:
TypeError: Fetch argument <_VariantDataset shapes: (), types: tf.int64> has invalid type <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>, must be a string or Tensor. (Can not convert a _VariantDataset into a Tensor or Operation.)
所以iterator.get_next()
返回的是VariantDataset
,而不是通常的张量。
TF版本:1.13.1
答案 0 :(得分:1)
Window会生成类似结构的数据集,在您的情况下该结构应该返回对{1,2}。不知道如何正确使用它或为什么存在它,但是设法使它像这样工作: 将tensorflow导入为tf
import tensorflow as tf
nxt = (tf.data.Dataset
.range(7)
.window(2, 1, 2, True)
.flat_map(lambda x: x.batch(2))
.make_one_shot_iterator()
.get_next()
)
with tf.Session() as sess:
print(sess.run(nxt))
答案 1 :(得分:0)
@ y.selivonchyk提供了有助于我理解这一正确答案。我添加了第二个示例,该示例使用滑动窗口来帮助向偶然发现此问题的人们阐明正确的方法。请特别注意,窗口大小和批处理大小相等。
import tensorflow as tf
window_size = 3
ds = tf.data.Dataset.range(20)
ds = ds.window(size=window_size, shift=1, stride=1, drop_remainder=False)
ds = ds.flat_map(lambda x: x.batch(window_size))
next_sample = ds.make_one_shot_iterator().get_next()
with tf.Session() as sess:
while True:
try:
print(sess.run(next_sample))
except tf.errors.OutOfRangeError:
print('EOF')
break
[0 1 2]
[1 2 3]
[2 3 4]
[3 4 5]
[4 5 6]
[5 6 7]
[6 7 8]
[7 8 9]
[ 8 9 10]
[ 9 10 11]
[10 11 12]
[11 12 13]
[12 13 14]
[13 14 15]
[14 15 16]
[15 16 17]
[16 17 18]
[17 18 19]
[18 19]
[19]
EOF