我在Windows 10上使用Tensorflow 1.5 GPU版本。
这是代码。
targets = convert_to_onehot(labels_dir, no_of_features = num_classes)
assert targets.shape == (8,120), 'THE TARGETS SHAPE IS NOT CORRECT'
targets = tf.constant(targets, dtype = tf.float32)
Images = [] #TO STORE THE RESIZED IMAGES IN THE FORM OF LIST TO PASS IT TO tf.train.batch()
#Initally having a list of 8 images just for Testing purpose2
images = glob.glob(images_file_path)
i = 0
for my_img in images:
image = mpimg.imread(my_img)[:, :, :3]
#print (image.shape)
image = tf.constant(image, dtype = tf.float32)
Images.append(image)
i = i + 1
if i == 8:
break
batch_size = 8
images, labels = tf.train.batch([Images, targets], batch_size = batch_size, num_threads = 1, capacity = batch_size)
with tf.Session() as sess:
print (images.shape)
print (labels.shape)
当我运行上面的代码时,它打印形状(8,8,299,299,3)和(8,8,120)而不是(8,299,299,3)和(8,120)
该功能要求我将图像作为列表传递,目标是一个numpy数组。
答案 0 :(得分:0)
前8个是您的批量大小。将其更改为4,您将看到它相应地发生变化。
你做错了是预先附加所有图像。 Images.append(image)
这样你的输入已经有了一批8个图像,最重要的是tf也在批处理,这不是应该怎么做的。
要解决此问题,您必须在enqueue_many=True
tf.train.batch
tf.train.batch([Images, targets], batch_size = batch_size, num_threads = 1, capacity = batch_size, enqueue_many=True)
在此处阅读更多内容:https://www.tensorflow.org/api_docs/python/tf/train/batch
答案 1 :(得分:0)
对于tf.train.batch中的参数,添加参数enqueue_many = True
,默认为False
。这告诉tensorflow第一个维度是样本的索引。
来源:tensorflow文档。 https://www.tensorflow.org/api_docs/python/tf/train/batch