我希望能够使用MXNet计算行向量之间的余弦距离。另外,我正在处理批量样本,并且想要计算每对样本的余弦距离(即批次#1的第1行向量的余弦距离与批次#2的第1行向量)。
两个向量之间的余弦距离定义为scipy.spatial.distance.cosine:
答案 0 :(得分:2)
您可以使用mx.nd.batch_dot
执行此批量余弦距离:
import mxnet as mx
def batch_cosine_dist(a, b):
a1 = mx.nd.expand_dims(a, axis=1)
b1 = mx.nd.expand_dims(b, axis=2)
d = mx.nd.batch_dot(a1, b1)[:,0,0]
a_norm = mx.nd.sqrt(mx.nd.sum((a*a), axis=1))
b_norm = mx.nd.sqrt(mx.nd.sum((b*b), axis=1))
dist = 1.0 - d / (a_norm * b_norm)
return dist
它将返回一个距离为batch_size
的数组。
batch_size = 3
dim = 2
a = mx.random.uniform(shape=(batch_size, dim))
b = mx.random.uniform(shape=(batch_size, dim))
dist = batch_cosine_dist(a, b)
print(dist.asnumpy())
# [ 0.04385382 0.25792354 0.10448891]