在张量流

时间:2017-05-05 07:25:24

标签: python numpy tensorflow

让我们考虑一个numpy矩阵,o

如果我们想使用numpy来使用以下功能:

o[np.arange(x), column_array]

我可以一次从numpy数组中获取多个索引。

我试图用tensorflow做同样的事情,但它并没有像我所做的那样工作。当o是张量流张量时;

o[tf.range(0, x, 1), column_array]

我收到以下错误:

TypeError: can only concatenate list (not "int") to list

我该怎么办?

2 个答案:

答案 0 :(得分:4)

您可以尝试tf.gather_nd(),这篇文章建议为How to select rows from a 3-D Tensor in TensorFlow?。 以下是从矩阵o获取多个索引的示例。

o = tf.constant([[1, 2, 3, 4],
                 [5, 6, 7, 8],
                 [9, 10, 11, 12],
                 [13, 14, 15, 16]])
# [row_index, column_index], I don’t figure out how to
# combine row vector and column vector into this form.
indices = tf.constant([[0, 0], [0, 1], [2, 1], [2, 3]])

result = tf.gather_nd(o, indices)

with tf.Session() as sess:
    print(sess.run(result)) #[ 1  2 10 12]

答案 1 :(得分:1)

您可能希望看到tf.gather_ndhttps://www.tensorflow.org/api_docs/python/tf/gather_nd

import tensorflow as tf
import numpy as np

tensor = tf.placeholder(tf.float32, [2,2])
indices = tf.placeholder(tf.int32, [2,2])
selected = tf.gather_nd(tensor, indices=indices)

with tf.Session() as session:
    data = np.array([[0.1,0.2],[0.3,0.4]])
    idx = np.array([[0,0],[1,1]])
    result = session.run(selected, feed_dict={indices:idx, tensor:data})
    print(result)

,结果为[ 0.1 0.40000001]