TensorFlow - How to batch a dataset

Date: 2016-12-22 08:10:52

Tags: tensorflow

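I am building a convolutional neural network for digit recognition. I want to train it on an image dataset, but I don't know how to feed the training data in batches.

I have two arrays, train_image and train_label: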

print(train_image.shape)
# (73257, 1024)
# 73257 images, each 32x32 = 1024 pixels

print(train_label.shape)
# (73257, 10)
# Digit '1' has label 1, '9' has label 9 and '0' has label 10

Now I want to train on the data in batches with a batch size of 50:

    sess.run(tf.initialize_all_variables())
    train_image_batch, train_label_batch = tf.train.shuffle_batch(
        [train_image, train_label],
        batch_size=50, capacity=50000, min_after_dequeue=10000)

When I print train_image_batch, I get:

print(train_image_batch)
# Tensor("shuffle_batch:0", shape=(50, 73257, 1024), dtype=uint8)

I expected the shape to be (50, 1024).

Am I doing something wrong here?

1 answer:

Answer 0: (score: 1)

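By default, shuffle_batch treats each input tensor as a single example, which is why the whole (73257, 1024) array ended up inside every batch element. Pass enqueue_many=True so that the first dimension of each tensor is treated as the example index. See the documentation.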

train_image_batch, train_label_batch = tf.train.shuffle_batch(
    [train_image, train_label], batch_size = 50, enqueue_many=True, capacity = 50000, min_after_dequeue = 10000)

print(train_image_batch.shape)

Output:
(50, 1024)
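
To make the shape rule concrete, here is a minimal pure-Python sketch of how the batch shape follows from the input shape in both modes. The helper name batched_shape is made up for illustration and is not part of TensorFlow; it only mirrors the behaviour described above and does no actual queuing or shuffling.

```python
def batched_shape(tensor_shape, batch_size, enqueue_many=False):
    """Shape of one batched output under the rule described above.

    Hypothetical helper for illustration only; not a TensorFlow API.
    """
    tensor_shape = tuple(tensor_shape)
    if enqueue_many:
        if not tensor_shape:
            raise ValueError("enqueue_many=True needs at least one dimension")
        # The leading dimension indexes the examples, so it is replaced
        # by the batch dimension.
        return (batch_size,) + tensor_shape[1:]
    # The whole tensor counts as one example, so the batch dimension
    # is prepended to the full shape.
    return (batch_size,) + tensor_shape


# Shapes taken from the question: 73257 images of 1024 pixels, 10-wide labels.
print(batched_shape((73257, 1024), 50))                     # (50, 73257, 1024)
print(batched_shape((73257, 1024), 50, enqueue_many=True))  # (50, 1024)
print(batched_shape((73257, 10), 50, enqueue_many=True))    # (50, 10)
```

The first call reproduces the unexpected shape from the question, and the second gives the (50, 1024) shape that was wanted. The labels follow the same rule, so train_label_batch should come out as (50, 10).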