TensorFlow提取列

时间:2017-06-19 20:10:43

标签: python tensorflow

我有一个形状的张量(10,100,20,3)。基本上,它可以被认为是一批图像。因此图像高度为100,宽度为20,通道深度为3。

我已经运行了一些计算来生成一组10 * 50个索引,这些索引对应于我希望在批处理中保留每个图像的50列。指数以张量形状(10,50)存储。我想最终得到一个形状的张量(10,50,20,3)。

我已经研究过tf.batch_nd(),但我无法弄清楚索引实际使用的语义。

有什么想法吗?

1 个答案:

答案 0 :(得分:0)

我无法对这个问题发表评论,因为代表率较低,所以请改为使用答案。

您能否稍微澄清一下您的问题,或许是使用非常小的张量的一个小具体例子?

您所指的“列”是什么?你说你想为每张图片保留50列(大概是50个数字)。如果是这样,(10,50)形状看起来就像你想要的那样 - 它对批次中的每个图像都有50个数字。您提到的(10,50,20,3)形状将为每个“image_column x channel”分配50个数字。即每张图像20 * 3 * 50 = 3000个数字。你想如何用你拥有的50来构建它们?

另外,您能否提供tf.batch_nd()的链接。我没有发现任何类似和相关的东西。