pytorch中的向量行和矩阵行余弦相似度

时间:2019-01-04 15:59:02

标签: pytorch cosine-similarity

在pytorch中,我有多个(十万个标度)300个暗淡矢量(我认为我应该在矩阵中上载),我想通过它们与另一个矢量的余弦相似度对它们进行排序,并提取出前1000个。我想避免for循环,因为这很耗时。我一直在寻找有效的解决方案。

1 个答案:

答案 0 :(得分:0)

您可以使用torch.nn.functional.cosine_similarity函数来计算余弦相似度。然后torch.argsort提取前1000名。

这里是一个例子:

x = torch.rand(10000,300)
y = torch.rand(1,300)
dist = F.cosine_similarity(x,y)
index_sorted = torch.argsort(dist)
top_1000 = index_sorted[:1000]

请注意y的形状,在调用相似函数之前不要忘记重塑形状。另请注意,argsort仅返回最接近向量的索引。要访问这些向量本身,只需编写x[top_1000],它将返回一个形状为(1000,300)的矩阵。