如何在tensorflow中复制numpy.choose()?

时间:2018-03-12 17:41:47

标签: python numpy tensorflow

我正在努力有效地复制numpy的ndarray.choose()方法。

这是我正在寻找的一个例子:

b = np.arange(15).reshape(3, 5)
c = np.array([1,0,4])
c.choose(b.T)  # trying to replicate in tensorflow
  -> array([ 1,  5, 14]) 

我能够做到的最好的事情就是生成一个batch_size方阵(如果批量很大,这个矩阵很大)并采用它的对角线:

tf_b = tf.constant(b)
tf_c = tf.constant(c)
sess.run(tf.diag_part(tf.gather(tf.transpose(tf_b), tf_c)))
-> array([ 1,  5, 14])

有没有办法做到这一点,在第一维(而不是平方)只是线性的?

1 个答案:

答案 0 :(得分:2)

是的,有更简单的方法可以做到这一点。将您的free数组展平为1-d,使其成为b。获取一系列索引,这些索引的数量范围在'选择范围内。你正在服用(3个你的情况)。那将是[0, 1, 2, ..., 13, 14]。将此范围乘以原始形状的第二个维度,即每个选项的选项数量(在您的情况下为5)。这会给你[0, 1, 2]。然后将索引添加到此处以获取[0, 5, 10]。现在你打电话给tf.gather()很好。

以下是我从here获取的一些代码,它为RNN输出做了类似的事情。你的会有所不同,但想法是一样的。

[1, 5, 14]

总体而言,操作非常简单。您使用范围操作来获取每行开头的索引,然后添加每行的位置索引。我认为在1D中这样做是最简单的,这就是我们为什么要压扁它的原因。