tensorflow聚集在多个维度

时间:2016-06-08 22:42:01

标签: python arrays numpy tensorflow

gather(params, indices)执行以下操作

output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :]

所以,如果你有4维参数和2维索引,你最终会得到5维数组

问题是怎么做

output[i, ..., j, :, ... :] = params[indices[i, :], ..., indices[j, :], :, ..., :]

以便它充当numpy's

output = params[indices[0], indices[1], .. , :]

(github上的#206票是关于不同的问题:它是关于类似numpy的api,而不是一般的聚会)

一种可能的方法是使用gather_nd,但是(据我所知),如果我们想要gather_nd并非所有维度,我们仍然需要为它们创建索引,例如如果我们有10维数组A并且我们想用二维数组B索引前两个维度,比如A [B [0],B [1],],我们的索引矩阵必须有11列(含8个冗余)。

--- old indices ----       new index
0 0 <all rows of length 8> 0
1 1 <all rows of length 8> 1
...

1 个答案:

答案 0 :(得分:1)

#206有一个更新@ebrevdo正致力于推广切片。

与此同时,你可以展平你的数组,为你想要的元素构建线性索引,使用聚集,然后重塑形状,就像mrry在another answer中所做的那样。效率可能不比原生实现差得多