让我们考虑一个numpy矩阵,o
:
如果我们想使用numpy来使用以下功能:
o[np.arange(x), column_array]
我可以一次从numpy数组中获取多个索引。
我试图用tensorflow做同样的事情,但它并没有像我所做的那样工作。当o
是张量流张量时;
o[tf.range(0, x, 1), column_array]
我收到以下错误:
TypeError: can only concatenate list (not "int") to list
我该怎么办?
答案 0 :(得分:4)
您可以尝试tf.gather_nd()
,这篇文章建议为How to select rows from a 3-D Tensor in TensorFlow?。
以下是从矩阵o
获取多个索引的示例。
o = tf.constant([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16]])
# [row_index, column_index], I don’t figure out how to
# combine row vector and column vector into this form.
indices = tf.constant([[0, 0], [0, 1], [2, 1], [2, 3]])
result = tf.gather_nd(o, indices)
with tf.Session() as sess:
print(sess.run(result)) #[ 1 2 10 12]
答案 1 :(得分:1)
您可能希望看到tf.gather_nd
:https://www.tensorflow.org/api_docs/python/tf/gather_nd
import tensorflow as tf
import numpy as np
tensor = tf.placeholder(tf.float32, [2,2])
indices = tf.placeholder(tf.int32, [2,2])
selected = tf.gather_nd(tensor, indices=indices)
with tf.Session() as session:
data = np.array([[0.1,0.2],[0.3,0.4]])
idx = np.array([[0,0],[1,1]])
result = session.run(selected, feed_dict={indices:idx, tensor:data})
print(result)
,结果为[ 0.1 0.40000001]