我无法使此代码正常工作,我在哪里错了?
dataset = tf.data.Dataset.from_tensors(np.arange(8))
dataset = dataset.apply(tf.contrib.data.sliding_window_batch(window_size=4))
iterator = dataset.make_one_shot_iterator()
element = iterator.get_next()
with tf.Session() as sess:
while True:
try:
print(sess.run(element))
except tf.errors.OutOfRangeError:
print('end')
break
我本来期望[0,1,2,3],[1,2,3,4],...
,但我什么也没得到。
编辑:
如果在print(dataset)
之前做apply
,我得到<TensorDataset shapes: (8,), types: tf.int64>
,在apply
之后,我得到<_SlideDataset shapes: (?, 8), types: tf.int64>
,这不是我期望的:{ {1}}是_SlideDataset
吗?
答案 0 :(得分:1)
将代码从from_tensors
更改为from_tensor_slices
。请参见下面的代码更新:
import tensorflow as tf
import numpy as np
dataset = tf.data.Dataset.from_tensor_slices((np.arange(8)))
dataset = dataset.apply(tf.contrib.data.sliding_window_batch(window_size=4))
iterator = dataset.make_one_shot_iterator()
element = iterator.get_next()
with tf.Session() as sess:
while True:
try:
print(sess.run(element))
except tf.errors.OutOfRangeError:
print('end')
break
[0 1 2 3]
[1 2 3 4]
[2 3 4 5]
[3 4 5 6]
[4 5 6 7]
end