如何在enqueue_many = true

时间:2016-03-02 00:23:12

标签: tensorflow

我正在寻找将tf.train.batchenqueue_many=True一起使用的示例。

在我的情况下,我有一个形状张量的形状[299,299,3],当我调用一个函数get_distortions(image)时,它将返回一个新的张量形状[10,299,299,3](在这个例子中,它将对图像应用10个失真并将它们全部作为新的张量返回)。然后,我想通过致电tf.train.batch将所有这些排队。

我试过了:

example_batch = tf.train.batch(tf.unpack(distortions), 5, enqueue_many=True)

但是当我sess.run(example_batch)时,我得到一份长度为10的清单(我期待一批5号)。

另外,在这种情况下,如何将标签包含在tf.train.batch中?所有10种扭曲的标签都是相同的。

2 个答案:

答案 0 :(得分:1)

请勿解压缩distortionsenqueue_many的语义是你给它一个张量,第一个维度是批处理维度,因此带有enqueue_many的[10, 299, 299, 3]张量将导致十个单独的项目,每个项目299, 299, 3被排队 - 这就是你想要的。

答案 1 :(得分:0)

tf.train.batch的文档告诉您:

  

如果enqueue_many为True,则假定张量代表一批   示例,其中第一个维度通过示例索引,以及所有   张量的成员在第一维中应具有相同的大小。   如果输入张量具有形状[*,x,y,z],则输出将具有形状   [batch_size,x,y,z]。 capacity参数控制了多长时间   允许预取会使队列增长。

您的情况恰恰如下:[10, 299, 299, 3],其中10是批量大小。因此,您无需进行任何解包,tf.train.batch(distortions, 5, enqueue_many=True)将完成此任务。