在Tensorflow中每行选择一个元素的优雅方法

时间:2016-05-04 11:24:05

标签: tensorflow

...鉴于

  • 形状为A
  • 的矩阵[m, n]
  • 形状I
  • 的张量[m]

我想从J获取A个元素列表 J[i] = A[i, I[i]]

也就是说,I保存要从A中的每一行中选择的元素的索引。

背景信息:我已经拥有argmax(A, 1),现在我也想要max。 我知道我可以使用reduce_max。 在尝试了一下后,我也提出了这个问题:

J = tf.gather_nd(A,
    tf.transpose(tf.pack([tf.to_int64(tf.range(A.get_shape()[0])), I])))

需要to_int64的地方因为范围只生成int32argmax只生成int64

这两个人都没有让我感到特别优雅。 一个具有运行时开销(可能约为因子n),另一个具有未知因素认知开销。我在这里错过了什么吗?

3 个答案:

答案 0 :(得分:3)

这是一个相当晚的答案,但可以做

mask = tf.one_hot(I, depth=n, dtype=tf.bool, on_value=True, off_value=False)
elements = tf.boolean_mask(A, mask)

完成你想要的东西?

编辑:我应该指出,如果A已经是一个非常大的张量,这不是一个好主意,因为这最终会产生一个密集的矩阵。

答案 1 :(得分:1)

gather()函数提供了一种实现方法:

r = tf.random.uniform([4,5],0, 9, dtype=tf.int32)
i = tf.random.uniform([4], 0, 4, dtype=tf.int32)
tf.gather(r, i, axis=1, batch_dims=1)

答案 2 :(得分:0)

由@ yaroslav-bulatov提供的Link提及this解决方案:

def get_elements(data, indices):
  indeces = tf.range(0, tf.shape(indices)[0])*data.shape[1] + indices
  return tf.gather(tf.reshape(data, [-1]), indeces)

您的解决方案目前无法区分(因为目前不支持tf.gather_nd的渐变)。

希望很快就会推出data[:, indices]