...鉴于
A
[m, n]
I
[m]
我想从J
获取A
个元素列表
J[i] = A[i, I[i]]
。
也就是说,I
保存要从A
中的每一行中选择的元素的索引。
背景信息:我已经拥有argmax(A, 1)
,现在我也想要max
。
我知道我可以使用reduce_max
。
在尝试了一下后,我也提出了这个问题:
J = tf.gather_nd(A,
tf.transpose(tf.pack([tf.to_int64(tf.range(A.get_shape()[0])), I])))
需要to_int64
的地方因为范围只生成int32
而argmax
只生成int64
。
这两个人都没有让我感到特别优雅。
一个具有运行时开销(可能约为因子n
),另一个具有未知因素认知开销。我在这里错过了什么吗?
答案 0 :(得分:3)
这是一个相当晚的答案,但可以做
mask = tf.one_hot(I, depth=n, dtype=tf.bool, on_value=True, off_value=False)
elements = tf.boolean_mask(A, mask)
完成你想要的东西?
编辑:我应该指出,如果A
已经是一个非常大的张量,这不是一个好主意,因为这最终会产生一个密集的矩阵。
答案 1 :(得分:1)
gather()
函数提供了一种实现方法:
r = tf.random.uniform([4,5],0, 9, dtype=tf.int32)
i = tf.random.uniform([4], 0, 4, dtype=tf.int32)
tf.gather(r, i, axis=1, batch_dims=1)
答案 2 :(得分:0)