Tensorflow:从一个张量中选择另一个张量的项目

时间:2017-07-18 11:31:26

标签: python tensorflow slice

我有一个值张量和一个重新排序的张量。重新排序张量给出了值张量中每一行的排序。如何使用此重新排序张量实际重新排序值张量中的值。

这会在numpy(Indexing one array by another in numpy)中得到所需的结果:

import numpy as np
values = np.array([
    [5,4,100],
    [10,20,500]
])
reorder_rows = np.array([
    [1,2,0],
    [0,2,1]
])


result = values[np.arange(values.shape[0])[:,None],reorder_rows]
print(result)

# [[  4 100   5]
#  [ 10 500  20]]

我怎样才能在tf中做同样的事?

我尝试使用切片和tf.gather_nd,但无法使其正常工作。

感谢。

2 个答案:

答案 0 :(得分:1)

以下tf code应该给出相同的结果:

values = tf.constant([
 [5,4,100],
 [10,20,500]
])
reorder_rows = tf.constant([
   [[0,1],[0,2],[0,0]],
   [[1,0],[1,2],[1,1]]
])
result = tf.gather_nd(values, reorder_rows)
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
result.eval()

#Result
#[[  4, 100,   5],
#[ 10, 500,  20]]

答案 1 :(得分:1)

尝试以下方法:

import numpy as np
values = np.array([
    [5,4,100],
    [10,20,500]
])
reorder_rows = np.array([
    [1,2,0],
    [0,2,1]
])

import tensorflow as tf

values = tf.constant(values)
reorder_rows = tf.constant(reorder_rows, dtype=tf.int32)
x = tf.tile(tf.range(tf.shape(values)[0])[:,tf.newaxis], [1,tf.shape(values)[1]])
res = tf.gather_nd(values, tf.stack([x, reorder_rows], axis=-1))
sess = tf.InteractiveSession()
res.eval()