如何通过保留特定维度的索引元素来减小张量的维度?

时间:2020-11-04 21:01:13

标签: python tensorflow tensorflow2.0 tensor

减小维度类似于reduce_max()所做的事情,不同之处在于我想要该维度中元素的特定索引,而不是简单地选择最大索引。例如,我有一个2x3张量A = [[0,1,2],[2,2,0]]。如果我应用tf.argmax(A),则会得到索引张量[1,1,0]。如何使用该索引张量[1,1,0]来获取张量为tf.reduce_max(A,0)= [2,2,2]?

我不直接使用tf.reduce_max的原因是我想使用不同的索引张量而不是argmax索引张量来减小维数或保持索引值而不是该维数的最大值。

1 个答案:

答案 0 :(得分:2)

您可以使用tf.gather_nd函数来执行此操作,但是您需要将[1, 1, 0]索引张量转换为2D张量。

这里我假设索引张量是一个numpy数组(您可以通过调用.numpy()方法将tensorflow张量转换为numpy数组。

idx = np.array([1, 1, 0])

idx = np.c_[idx[:, np.newaxis], np.arange(len(idx))]

print(idx)

# Output:
# array([[1, 0],
#        [1, 1],
#        [0, 2]])

这意味着:使用上述tf.gather_nd时选择(row1,col0),(row1,col1)和(row0,col2)

A = tf.Variable([[0, 1, 2], [2, 2, 0]])

tf.gather_nd(A, idx)

将为您提供预期的[2, 2, 2]张量。