我正在努力有效地复制numpy的ndarray.choose()
方法。
这是我正在寻找的一个例子:
b = np.arange(15).reshape(3, 5)
c = np.array([1,0,4])
c.choose(b.T) # trying to replicate in tensorflow
-> array([ 1, 5, 14])
我能够做到的最好的事情就是生成一个batch_size方阵(如果批量很大,这个矩阵很大)并采用它的对角线:
tf_b = tf.constant(b)
tf_c = tf.constant(c)
sess.run(tf.diag_part(tf.gather(tf.transpose(tf_b), tf_c)))
-> array([ 1, 5, 14])
有没有办法做到这一点,在第一维(而不是平方)只是线性的?
答案 0 :(得分:2)
是的,有更简单的方法可以做到这一点。将您的free
数组展平为1-d,使其成为b
。获取一系列索引,这些索引的数量范围在'选择范围内。你正在服用(3个你的情况)。那将是[0, 1, 2, ..., 13, 14]
。将此范围乘以原始形状的第二个维度,即每个选项的选项数量(在您的情况下为5)。这会给你[0, 1, 2]
。然后将索引添加到此处以获取[0, 5, 10]
。现在你打电话给tf.gather()很好。
以下是我从here获取的一些代码,它为RNN输出做了类似的事情。你的会有所不同,但想法是一样的。
[1, 5, 14]
总体而言,操作非常简单。您使用范围操作来获取每行开头的索引,然后添加每行的位置索引。我认为在1D中这样做是最简单的,这就是我们为什么要压扁它的原因。