我正在寻找将tf.train.batch
与enqueue_many=True
一起使用的示例。
在我的情况下,我有一个形状张量的形状[299,299,3],当我调用一个函数get_distortions(image)
时,它将返回一个新的张量形状[10,299,299,3](在这个例子中,它将对图像应用10个失真并将它们全部作为新的张量返回)。然后,我想通过致电tf.train.batch
将所有这些排队。
我试过了:
example_batch = tf.train.batch(tf.unpack(distortions), 5, enqueue_many=True)
但是当我sess.run(example_batch)
时,我得到一份长度为10的清单(我期待一批5号)。
另外,在这种情况下,如何将标签包含在tf.train.batch
中?所有10种扭曲的标签都是相同的。
答案 0 :(得分:1)
请勿解压缩distortions
。 enqueue_many
的语义是你给它一个张量,第一个维度是批处理维度,因此带有enqueue_many的[10, 299, 299, 3]
张量将导致十个单独的项目,每个项目299, 299, 3
被排队 - 这就是你想要的。
答案 1 :(得分:0)
tf.train.batch
的文档告诉您:
如果enqueue_many为True,则假定张量代表一批 示例,其中第一个维度通过示例索引,以及所有 张量的成员在第一维中应具有相同的大小。 如果输入张量具有形状[*,x,y,z],则输出将具有形状 [batch_size,x,y,z]。 capacity参数控制了多长时间 允许预取会使队列增长。
您的情况恰恰如下:[10, 299, 299, 3]
,其中10是批量大小。因此,您无需进行任何解包,tf.train.batch(distortions, 5, enqueue_many=True)
将完成此任务。