在TensorFlow中,我们说我们的训练数据xs
采用numpy NHCW格式。我想在Tensorflow中对来自xs
的批次进行抽样,我做了
xs = np.reshape(range(32), [4,2,2,2])
tensor_list = [tf.convert_to_tensors(x) for x in xs]
#x_tensor = tf.convert_to_tensors(xs) # tried this version too
x_batch = tf.train.shuffle_batch(tensor_list, batch_size=3, capacity=50, min_after_dequeue=10)
此代码不是从tensor_list
采样,而是返回一个列表,其长度与数据点的数量相同(在本例中为4),每个列表元素都是张量,其中第一个维度为{{ 1}}(在这种情况下为3)。个人直觉结果是batch_size
是一个4维张量,第一维的值是x_batch
,内容是随机抽样的。然后,每当我们拨打batch_size
时,我们都会有不同的批次。
请让我知道我做错了什么。
答案 0 :(得分:0)
Haven没有想到这一点,但发现了一个不同的解决方案如下。
xs = np.reshape(range(32), [4,2,2,2])
x_tensor = tf.convert_to_tensor(xs)
dataset = tf.data.Dataset.from_tensor_slices(x_tensor)
dataset = dataset.batch(2)
dataset = dataset.shuffle(buffer_size=10000)
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer)
while True:
try:
next_element = iterator.get_next()
except tf.errors.OutOfRangeError:
print("End of dataset") # ==> "End of dataset"
然后,每次调用iterator.get_next()
时,这将输出一个大小为2的随机批次。