计算MXNet中的余弦距离

时间:2018-04-02 19:08:32

标签: mxnet

我希望能够使用MXNet计算行向量之间的余弦距离。另外,我正在处理批量样本,并且想要计算每对样本的余弦距离(即批次#1的第1行向量的余弦距离与批次#2的第1行向量)。

两个向量之间的余弦距离定义为scipy.spatial.distance.cosine

enter image description here

1 个答案:

答案 0 :(得分:2)

您可以使用mx.nd.batch_dot执行此批量余弦距离:

import mxnet as mx

def batch_cosine_dist(a, b):
    a1 = mx.nd.expand_dims(a, axis=1)
    b1 = mx.nd.expand_dims(b, axis=2)
    d = mx.nd.batch_dot(a1, b1)[:,0,0]
    a_norm = mx.nd.sqrt(mx.nd.sum((a*a), axis=1))
    b_norm = mx.nd.sqrt(mx.nd.sum((b*b), axis=1))
    dist = 1.0 - d / (a_norm * b_norm)
    return dist

它将返回一个距离为batch_size的数组。

batch_size = 3
dim = 2
a = mx.random.uniform(shape=(batch_size, dim))
b = mx.random.uniform(shape=(batch_size, dim))
dist = batch_cosine_dist(a, b)
print(dist.asnumpy())

# [ 0.04385382  0.25792354  0.10448891]