减小维度类似于reduce_max()所做的事情,不同之处在于我想要该维度中元素的特定索引,而不是简单地选择最大索引。例如,我有一个2x3张量A = [[0,1,2],[2,2,0]]。如果我应用tf.argmax(A),则会得到索引张量[1,1,0]。如何使用该索引张量[1,1,0]来获取张量为tf.reduce_max(A,0)= [2,2,2]?
我不直接使用tf.reduce_max的原因是我想使用不同的索引张量而不是argmax索引张量来减小维数或保持索引值而不是该维数的最大值。
答案 0 :(得分:2)
您可以使用tf.gather_nd
函数来执行此操作,但是您需要将[1, 1, 0]
索引张量转换为2D张量。
这里我假设索引张量是一个numpy数组(您可以通过调用.numpy()
方法将tensorflow张量转换为numpy数组。
idx = np.array([1, 1, 0])
idx = np.c_[idx[:, np.newaxis], np.arange(len(idx))]
print(idx)
# Output:
# array([[1, 0],
# [1, 1],
# [0, 2]])
这意味着:使用上述tf.gather_nd
时选择(row1,col0),(row1,col1)和(row0,col2)
A = tf.Variable([[0, 1, 2], [2, 2, 0]])
tf.gather_nd(A, idx)
将为您提供预期的[2, 2, 2]
张量。