例如,我有张量
x = tf.constant([[1, 2], [1, 2], [2, 3], [4, 5], [4,5]])
然后我有一个聚集索引列表
idx = [[0,1],[2], [3, 4]]
并应用于x
并取每个张量的均值
y = []
for i in idx:
y.append(tf.reduce_mean(tf.gather(x, i, 0), 0))
最后,将它们堆叠在一起
y = tf.stack(y, 0)
我想得到的结果是
tensor([[1, 2], [2, 3], [4, 5]])
它有问题,for循环效率不高,有人可以帮我解决吗?
答案 0 :(得分:0)
这对您有用吗?请验证。请注意,出于某些原因,它会打印浮点数,并且Tensorflow的版本为1.13
。我还不确定这与您的相比有多有效,因为我没有微基准测试。
x = tf.constant([[1, 2], [1, 2], [2, 3], [4, 5], [4,5]])
print( sess.run(tf.reduce_mean(tf.gather(x,tf.ragged.constant([[0,1],[2], [3, 4]]),0),1) ))
打印
[[1. 2.]
[2. 3.]
[4. 5.]]