tf.train.batch函数正在生成批量形状的张量(8,8,299,299,3)

时间:2018-02-23 16:19:04

标签: python tensorflow

我在Windows 10上使用Tensorflow 1.5 GPU版本。

这是代码。

targets = convert_to_onehot(labels_dir, no_of_features = num_classes)
assert targets.shape == (8,120), 'THE TARGETS SHAPE IS NOT CORRECT'
targets = tf.constant(targets, dtype = tf.float32)

Images = [] #TO STORE THE RESIZED IMAGES IN THE FORM OF LIST TO PASS IT TO tf.train.batch()
#Initally having a list of 8 images just for Testing purpose2
images = glob.glob(images_file_path)
i = 0
for my_img in images:
    image = mpimg.imread(my_img)[:, :, :3]
    #print (image.shape)
    image = tf.constant(image, dtype = tf.float32)
    Images.append(image)
    i = i + 1
    if i == 8:
        break

batch_size = 8
images, labels = tf.train.batch([Images, targets], batch_size = batch_size, num_threads = 1, capacity = batch_size)
with tf.Session() as sess:
    print (images.shape)
    print (labels.shape)

当我运行上面的代码时,它打印形状(8,8,299,299,3)和(8,8,120)而不是(8,299,299,3)和(8,120)

该功能要求我将图像作为列表传递,目标是一个numpy数组。

2 个答案:

答案 0 :(得分:0)

前8个是您的批量大小。将其更改为4,您将看到它相应地发生变化。

你做错了是预先附加所有图像。 Images.append(image)

这样你的输入已经有了一批8个图像,最重要的是tf也在批处理,这不是应该怎么做的。

要解决此问题,您必须在enqueue_many=True

中传递tf.train.batch

tf.train.batch([Images, targets], batch_size = batch_size, num_threads = 1, capacity = batch_size, enqueue_many=True)

在此处阅读更多内容:https://www.tensorflow.org/api_docs/python/tf/train/batch

答案 1 :(得分:0)

对于tf.train.batch中的参数,添加参数enqueue_many = True,默认为False。这告诉tensorflow第一个维度是样本的索引。

来源:tensorflow文档。 https://www.tensorflow.org/api_docs/python/tf/train/batch