我有一个形状的张量(2,3),比如input = [[1 2 3] [4 5 6]]
,我有一个形状的索引张量(2,3),我希望用它来从输入中检索值,{{1 }}。我的预期结果是index = [[1 0] [2 0]]
。但是,仅使用result = [[2 1] [6 4]]
似乎无效。
答案 0 :(得分:2)
如果要从数组中提取元素,可以使用gather_nd
,并且索引的格式应为每个元素的(i,j)
。在您的示例中,您的索引应为:
inputs = tf.Variable([[1, 2, 3], [4, 5, 6]])
index = tf.Variable([[[0,1],[0,0]], [[1,2],[1,0]]])
result = tf.gather_nd(inputs, new_index)
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
print(result.eval())
# output
# [[2 1]
# [6 4]]
如果你想从你提到的表格中生成索引,你可以这样做:
index = tf.Variable([[1, 0], [2, 0]])
r = tf.tile(tf.expand_dims(tf.range(tf.shape(index1)[0]), 1), [1, 2])
new_index = tf.concat([tf.expand_dims(r,-1), tf.expand_dims(index, -1)], axis=2)
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
print(new_index.eval())
#output
#[[[0 1]
#[0 0]]
#[[1 2]
#[1 0]]]
答案 1 :(得分:0)
问题出在index
,您必须在index
中使用0或1个值,因为您的input
数组的形状为(2,3)
。如果向input
数组中添加其他行,则所有工作正常:
import tensorflow as tf
input = tf.Variable([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
index = tf.Variable([[1, 0], [2, 0]])
result = tf.gather(input, index)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(result))
# results [[[4 5 6] [1 2 3]] [[7 8 9] [1 2 3]]]
无论如何,index
描述从input
数组收集的切片,而不是元素。