使用tf.gather获取张量的值

时间:2017-07-26 17:57:21

标签: tensorflow

我有一个形状的张量(2,3),比如input = [[1 2 3] [4 5 6]],我有一个形状的索引张量(2,3),我希望用它来从输入中检索值,{{1 }}。我的预期结果是index = [[1 0] [2 0]]。但是,仅使用result = [[2 1] [6 4]]似乎无效。

2 个答案:

答案 0 :(得分:2)

如果要从数组中提取元素,可以使用gather_nd,并且索引的格式应为每个元素的(i,j)。在您的示例中,您的索引应为:

inputs = tf.Variable([[1, 2, 3], [4, 5, 6]])
index = tf.Variable([[[0,1],[0,0]], [[1,2],[1,0]]])

result = tf.gather_nd(inputs, new_index)

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
print(result.eval())
# output
# [[2 1]
# [6 4]]

如果你想从你提到的表格中生成索引,你可以这样做:

index = tf.Variable([[1, 0], [2, 0]])

r = tf.tile(tf.expand_dims(tf.range(tf.shape(index1)[0]), 1), [1, 2])
new_index = tf.concat([tf.expand_dims(r,-1), tf.expand_dims(index, -1)], axis=2)

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
print(new_index.eval())
#output
#[[[0 1]
#[0 0]]
#[[1 2]
#[1 0]]]

答案 1 :(得分:0)

问题出在index,您必须在index中使用0或1个值,因为您的input数组的形状为(2,3)。如果向input数组中添加其他行,则所有工作正常:

import tensorflow as tf

input = tf.Variable([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
index = tf.Variable([[1, 0], [2, 0]])
result = tf.gather(input, index)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(result))
# results [[[4 5 6] [1 2 3]] [[7 8 9] [1 2 3]]]

无论如何,index描述从input数组收集的切片,而不是元素