我有一个值张量和一个重新排序的张量。重新排序张量给出了值张量中每一行的排序。如何使用此重新排序张量实际重新排序值张量中的值。
这会在numpy(Indexing one array by another in numpy)中得到所需的结果:
import numpy as np
values = np.array([
[5,4,100],
[10,20,500]
])
reorder_rows = np.array([
[1,2,0],
[0,2,1]
])
result = values[np.arange(values.shape[0])[:,None],reorder_rows]
print(result)
# [[ 4 100 5]
# [ 10 500 20]]
我怎样才能在tf中做同样的事?
我尝试使用切片和tf.gather_nd,但无法使其正常工作。
感谢。
答案 0 :(得分:1)
以下tf code
应该给出相同的结果:
values = tf.constant([
[5,4,100],
[10,20,500]
])
reorder_rows = tf.constant([
[[0,1],[0,2],[0,0]],
[[1,0],[1,2],[1,1]]
])
result = tf.gather_nd(values, reorder_rows)
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
result.eval()
#Result
#[[ 4, 100, 5],
#[ 10, 500, 20]]
答案 1 :(得分:1)
尝试以下方法:
import numpy as np
values = np.array([
[5,4,100],
[10,20,500]
])
reorder_rows = np.array([
[1,2,0],
[0,2,1]
])
import tensorflow as tf
values = tf.constant(values)
reorder_rows = tf.constant(reorder_rows, dtype=tf.int32)
x = tf.tile(tf.range(tf.shape(values)[0])[:,tf.newaxis], [1,tf.shape(values)[1]])
res = tf.gather_nd(values, tf.stack([x, reorder_rows], axis=-1))
sess = tf.InteractiveSession()
res.eval()