为什么tf.train.batch会在TensorFlow中为输出添加额外的维度?

时间:2017-12-06 18:30:25

标签: tensorflow

在TensorFlow中,我们说我们的训练数据xs采用numpy NHCW格式。我想在Tensorflow中对来自xs的批次进行抽样,我做了

xs = np.reshape(range(32), [4,2,2,2])    
tensor_list = [tf.convert_to_tensors(x) for x in xs]
#x_tensor = tf.convert_to_tensors(xs) # tried this version too
x_batch = tf.train.shuffle_batch(tensor_list, batch_size=3, capacity=50, min_after_dequeue=10)

此代码不是从tensor_list采样,而是返回一个列表,其长度与数据点的数量相同(在本例中为4),每个列表元素都是张量,其中第一个维度为{{ 1}}(在这种情况下为3)。个人直觉结果是batch_size是一个4维张量,第一维的值是x_batch,内容是随机抽样的。然后,每当我们拨打batch_size时,我们都会有不同的批次。

请让我知道我做错了什么。

1 个答案:

答案 0 :(得分:0)

Haven没有想到这一点,但发现了一个不同的解决方案如下。

xs = np.reshape(range(32), [4,2,2,2])    
x_tensor = tf.convert_to_tensor(xs)
dataset = tf.data.Dataset.from_tensor_slices(x_tensor)
dataset = dataset.batch(2)
dataset = dataset.shuffle(buffer_size=10000)
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer)

while True:
  try:
    next_element = iterator.get_next()
  except tf.errors.OutOfRangeError:
    print("End of dataset")  # ==> "End of dataset"

然后,每次调用iterator.get_next()时,这将输出一个大小为2的随机批次。