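How can pairwise cosine distances between the rows of a matrix be computed efficiently in TensorFlow? Given an MxN matrix, the result should be an MxM matrix in which the element at position [i][j] is the cosine distance between row (vector) i and row (vector) j of the input matrix.
This is easy to do with Scikit-Learn: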
from sklearn.metrics.pairwise import pairwise_distances
pairwise_distances(input_matrix, metric='cosine')
Is there an equivalent method in TensorFlow?
Answer 0 (score: 5)
There is an answer covering a single cosine distance here: https://stackoverflow.com/a/46057597/288875. It is based on tf.losses.cosine_distance.
Here is a solution that does this for a matrix:
import numpy as np
import tensorflow as tf
from sklearn.metrics.pairwise import pairwise_distances

# uses the TensorFlow 1.x graph-mode API (tf.Session, tf.placeholder)
with tf.Session() as sess:
    M = 3

    # input: one vector per row (an M x M matrix in this example)
    input_ph = tf.placeholder(tf.float32, shape = (M, M))

    # normalize each row to unit L2 norm (second argument is the axis)
    normalized = tf.nn.l2_normalize(input_ph, 1)

    # matrix product with the transpose: entry [i][j] is the dot product
    # of normalized rows i and j, i.e. their cosine similarity
    prod = tf.matmul(normalized, normalized,
                     adjoint_b = True  # transpose second matrix
                     )

    # cosine distance = 1 - cosine similarity
    dist = 1 - prod

    input_matrix = np.array(
        [[ 1, 1, 1 ],
         [ 0, 1, 1 ],
         [ 0, 0, 1 ],
        ],
        dtype = 'float32')

    print("input_matrix:")
    print(input_matrix)

    print("sklearn:")
    print(pairwise_distances(input_matrix, metric='cosine'))

    print("tensorflow:")
    print(sess.run(dist, feed_dict = { input_ph : input_matrix }))
This gives me:
input_matrix:
[[ 1. 1. 1.]
[ 0. 1. 1.]
[ 0. 0. 1.]]
sklearn:
[[ 0. 0.18350345 0.42264974]
[ 0.18350345 0. 0.29289323]
[ 0.42264974 0.29289323 0. ]]
tensorflow:
[[ 5.96046448e-08 1.83503449e-01 4.22649741e-01]
[ 1.83503449e-01 5.96046448e-08 2.92893231e-01]
[ 4.22649741e-01 2.92893231e-01 0.00000000e+00]]
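Note that this solution may not be optimal: it computes every entry of the (symmetric) result matrix, so it does almost twice the necessary work. That is probably not a problem for small matrices; for large matrices, a combination of loops might be faster.
Also note that this has no mini-batch dimension, so it only works for a single matrix.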
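For the mini-batch case, a minimal sketch follows. It assumes TensorFlow 2.x with eager execution (not the 1.x session API used above), and the function name batched_cosine_distances is made up for illustration. It relies on tf.matmul treating the leading dimension of a rank-3 tensor as a batch dimension, so the same normalize-then-multiply idea is applied to each matrix independently.

```python
import numpy as np
import tensorflow as tf

def batched_cosine_distances(x):
    """Pairwise cosine distances for a batch of matrices.

    x has shape (batch, M, N); the result has shape (batch, M, M).
    Hypothetical helper, assuming TensorFlow 2.x eager execution.
    """
    # normalize every row vector (last axis) to unit L2 norm
    normalized = tf.nn.l2_normalize(x, axis=-1)
    # batched matrix product: for each batch element, rows times rows^T
    similarity = tf.matmul(normalized, normalized, transpose_b=True)
    # cosine distance = 1 - cosine similarity
    return 1.0 - similarity

batch = np.array(
    [[[1, 1, 1],
      [0, 1, 1],
      [0, 0, 1]],
     [[1, 0, 0],
      [0, 1, 0],
      [1, 1, 0]]],
    dtype="float32")

batched_dist = batched_cosine_distances(tf.constant(batch)).numpy()
print(batched_dist.shape)
print(batched_dist[0])
```

The first matrix in the batch is the same input as above, so its slice of the result can be compared directly against the single-matrix output.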
Answer 1 (score: 0)
An elegant solution (the output matches that of scikit-learn's pairwise_distances function, up to floating-point rounding):
import numpy as np
import tensorflow as tf

def compute_cosine_distances(a, b):
    # a shape is n_a * dim
    # b shape is n_b * dim
    # result shape is n_a * n_b
    normalize_a = tf.nn.l2_normalize(a, 1)
    normalize_b = tf.nn.l2_normalize(b, 1)
    # cosine distance = 1 - cosine similarity
    distance = 1 - tf.matmul(normalize_a, normalize_b, transpose_b=True)
    return distance
Test:
input_matrix = np.array([[1, 1, 1],
[0, 1, 1],
[0, 0, 1]], dtype = 'float32')
compute_cosine_distances(input_matrix, input_matrix)
Output:
<tf.Tensor: id=442, shape=(3, 3), dtype=float32, numpy=
array([[5.9604645e-08, 1.8350345e-01, 4.2264974e-01],
[1.8350345e-01, 5.9604645e-08, 2.9289323e-01],
[4.2264974e-01, 2.9289323e-01, 0.0000000e+00]], dtype=float32)>
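In the output above, two diagonal entries come out as 5.9604645e-08 instead of 0 because of float32 rounding. If exact zeros on the diagonal matter, one possible cleanup is sketched below. It assumes TensorFlow 2.x with eager execution and only covers the case where a matrix is compared with itself; the function name cosine_distances_zero_diag is hypothetical. It clamps negative rounding errors with tf.maximum and overwrites the diagonal with tf.linalg.set_diag.

```python
import numpy as np
import tensorflow as tf

def cosine_distances_zero_diag(a):
    """Pairwise cosine distances of the rows of `a` with a zeroed diagonal.

    Hypothetical helper, assuming TensorFlow 2.x eager execution.
    a has shape (n, dim); the result has shape (n, n).
    """
    normalized = tf.nn.l2_normalize(a, axis=1)
    dist = 1.0 - tf.matmul(normalized, normalized, transpose_b=True)
    # rounding can push entries slightly below zero; clamp them
    dist = tf.maximum(dist, 0.0)
    # the distance of a row to itself is zero by definition
    zeros = tf.zeros(tf.shape(dist)[:1], dtype=dist.dtype)
    return tf.linalg.set_diag(dist, zeros)

zero_diag_input = np.array([[1, 1, 1],
                            [0, 1, 1],
                            [0, 0, 1]], dtype="float32")
zero_diag_result = cosine_distances_zero_diag(zero_diag_input).numpy()
print(zero_diag_result)
```

The off-diagonal entries are unchanged by this step; only the self-distances are forced to zero.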