我想计算存储在某些2D张量中的所有点之间的成对距离('d'维度中的'batch_size'点)。
我目前正在使用下面的tensorflow代码,这很不错,但是我希望有一个更快的解决方案,因为每次迭代的时间成本在很大程度上取决于此步骤。
'''python
Z = tf.Variable(np.empty((batch_size,d))
Z = tf.reshape(Zy, (batch_size, 1, d)) -
tf.reshape(Z, (1, batch_size, d))
norms = tf.norm(Z, axis=2)
'''
基本上,这是做一个中间3D张量Z,其中Z(i,j,k)表示点i和点j在第k个坐标中的差。有什么想法吗?