我需要从Tensor收集一些数据,我使用了gather_nd
。现在代码在
import tensorflow as tf
indices = [[[0, 4], [0, 1], [0, 6], [0, 2]],
[[1, 1], [1, 4], [1, 0], [1, 9]],
[[2, 5], [2, 1], [2, 9], [2, 6]]]
params = [[4,6,3,6,7,8,4,5,3,8], [9,5,6,2,6,5,1,9,6,4], [4,6,6,1,3,2,6,7,1,8]]
output = tf.gather_nd(params, indices)
sess = tf.Session()
print sess.run(output)
输出
[[7 6 4 3]
[5 6 9 4]
[2 6 8 6]]
是的,这就是我想要的。我想取出位于params [0]的4,1,6,2处的值。它们是7,6,4,3,因为params [0] [4] = 7,params [0] [1] = 6,params [0] [6] = 4,params [0] [2] = 3。
但是,tf.gather_nd
只接收上述索引。现在我的raw_indices就像,
[[4, 1, 6, 2],
[1, 4, 0, 9],
[5, 1, 9, 6]]
如何在张量流中将raw_indices
转移到indices
?是的,我必须在张量图中执行此步骤,因为在图的中间生成了raw_indices
。
答案 0 :(得分:1)
tf.range()和一些拼贴的混合似乎有效:
def index_matrix_to_pairs(index_matrix):
replicated_first_indices = tf.tile(
tf.expand_dims(tf.range(tf.shape(index_matrix)[0]), dim=1),
[1, tf.shape(index_matrix)[1]])
return tf.pack([replicated_first_indices, index_matrix], axis=2)
start = [[4, 1, 6, 2],
[1, 4, 0, 9],
[5, 1, 9, 6]]
with tf.Session():
print(index_matrix_to_pairs(start).eval())
给出:
[[[0 4]
[0 1]
[0 6]
[0 2]]
[[1 1]
[1 4]
[1 0]
[1 9]]
[[2 5]
[2 1]
[2 9]
[2 6]]]
它只是使用平铺的tf.range()op生成每对的第一部分,然后用指定的索引打包。