Tensorflow通过另一个张量进行多维索引

时间:2017-09-21 17:22:24

标签: python numpy indexing tensorflow

说,我有一个形状[n1,n2,...,nk]的等级k张量X和形状[n2,n3,...,nk]的等级 - (k-1)张量IDX其中IDX具有与X的最后(k-1)维相同的形状.IDX的条目都是[0,n1]中的整数。我想从X中获取一些值,其中第一个维度位置由IDX指定,而其他维度则全部迭代。

示例:

X = tf.constant([[1,2], [3,4], [5,6],
                 [7,8], [9,10],[11,12]]) # 2 x 3 x 2 tensor
IDX = tf.constant([[1,0], [1,1], [0,1]]) #     3 x 2 tensor
...
# would like to get [[7,2],[9,10],[5,12]]

如何在Tensorflow中有效实现这一目标?谢谢!

2 个答案:

答案 0 :(得分:0)

您是否看到了choose的说明?

  

注释

     

为了减少误解的可能性,即使如下   "滥用"在名义上得到支持,choices既不应该,也不应该   被认为是单个阵列,即最外层的序列式容器   应该是列表或元组。

也就是说,他们希望你像对待它一样:

In [432]: list(X)
Out[432]: [array([1, 2]), array([3, 4]), array([5, 6])]
In [433]: np.choose(IDX,list(X))
Out[433]: array([3, 6])

索引等效项为:

In [436]: X[IDX,np.arange(2)]
Out[436]: array([3, 6])

choose也有一些mode选项。

文档也说它相当于(减去这些模式问题):

np.choose(a,c) == np.array([c[a[I]][I] for I in ndi.ndindex(a.shape)])

choose的另一个细微差别。它无法使用超过32种选择。

In [440]: np.choose(IDX,np.ones((33,2)))
...
ValueError: Need at least 1 and at most 32 array objects.

In [442]: np.ones((33,2))[IDX,np.arange(2)]
Out[442]: array([ 1.,  1.])

答案 1 :(得分:0)

您可以将np.choose()包装在python函数中,并使用tf.py_func()将其嵌入到tensorflow图中。但是,如果您希望自动梯度计算图表以便为您提供培训,您还可以为函数定义渐变。如果实际上可以解决的话,定义np.choose()的渐变可能是非常棘手的任务。