如何使用tensorflow中的gather_nd在我的案例中收集数据?

时间:2016-11-01 14:22:18

标签: python tensorflow

我需要从Tensor收集一些数据,我使用了gather_nd。现在代码在

之上
import tensorflow as tf

indices = [[[0, 4], [0, 1], [0, 6], [0, 2]],
           [[1, 1], [1, 4], [1, 0], [1, 9]],
           [[2, 5], [2, 1], [2, 9], [2, 6]]]

params = [[4,6,3,6,7,8,4,5,3,8], [9,5,6,2,6,5,1,9,6,4], [4,6,6,1,3,2,6,7,1,8]]
output = tf.gather_nd(params, indices)

sess = tf.Session()
print sess.run(output)

输出

[[7 6 4 3]
 [5 6 9 4]
 [2 6 8 6]]
是的,这就是我想要的。我想取出位于params [0]的4,1,6,2处的值。它们是7,6,4,3,因为params [0] [4] = 7,params [0] [1] = 6,params [0] [6] = 4,params [0] [2] = 3。

但是,tf.gather_nd只接收上述索引。现在我的raw_indices就像,

[[4, 1, 6, 2],
 [1, 4, 0, 9],
 [5, 1, 9, 6]]

如何在张量流中将raw_indices转移到indices?是的,我必须在张量图中执行此步骤,因为在图的中间生成了raw_indices

1 个答案:

答案 0 :(得分:1)

tf.range()和一些拼贴的混合似乎有效:

def index_matrix_to_pairs(index_matrix):
  replicated_first_indices = tf.tile(
      tf.expand_dims(tf.range(tf.shape(index_matrix)[0]), dim=1), 
      [1, tf.shape(index_matrix)[1]])
  return tf.pack([replicated_first_indices, index_matrix], axis=2)

start = [[4, 1, 6, 2],
 [1, 4, 0, 9],
 [5, 1, 9, 6]]
with tf.Session():
  print(index_matrix_to_pairs(start).eval())

给出:

[[[0 4]
  [0 1]
  [0 6]
  [0 2]]

 [[1 1]
  [1 4]
  [1 0]
  [1 9]]

 [[2 5]
  [2 1]
  [2 9]
  [2 6]]]

它只是使用平铺的tf.range()op生成每对的第一部分,然后用指定的索引打包。