我在张量流中实现了平方成对的eucledian距离,但是无法将其与sklearn库匹配。我使用以下方法实现了它:
dist(x,y)= dot(x,x) - 2 * dot(x,y)+ dot(y,y)
但无法获得匹配的功能。有谁知道什么可能是错的?这是脚本:
import numpy as np
import tensorflow as tf
import sklearn as sk
from sklearn.metrics.pairwise import euclidean_distances
def get_Z_tf(x,W,S,l='layer'):
W = tf.Variable(W, name='W'+l, trainable=True, dtype=tf.float64)
S = tf.Variable(S, name='S'+l, trainable=True, dtype=tf.float64)
WW = tf.reduce_sum(W*W, reduction_indices=0, keep_dims=True) #( 1 x D^(l)= sum( (D^(l-1) x D^(l)), 0 )
XX = tf.reduce_sum(x*x, reduction_indices=1, keep_dims=True) # (M x 1) = sum( (M x D^(l-1)), 1 )
# -|| x - w ||^2 = -(-2<x,w> + ||x||^2 + ||w||^2) = 2<x,w> - (||x||^2 + ||w||^2)
Delta_tilde = 2.0*tf.matmul(x,W) - tf.add(WW, XX)
return Delta_tilde
def get_z_np(x,W,S):
WW = np.sum(np.multiply(W,W), axis=0, dtype=None, keepdims=True)
XX = np.sum(np.multiply(x,x), axis=1, dtype=None, keepdims=True)
Delta_tilde = 2.0*np.dot(x,W) - (WW + XX)
return Delta_tilde
W = np.random.rand(4,3)
x = np.random.rand(5,4)
S = 0.9
#x_tf = tf.constant(x)
#sklearn.metrics.pairwise.euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False, X_norm_squared=None)[source]
with tf.Session() as sess:
x_const = tf.constant(x)
Z_tf = get_Z_tf(x_const,W,S)
sess.run(tf.initialize_all_variables())
Z_tf = sess.run(Z_tf)
Z_np = get_z_np(x,W,S)
#Z_sk = S*sklearn.metrics.pairwise.euclidean_distances(X=X,Y=W,squared=True)
Z_sk = -S*euclidean_distances(X=x,Y=np.transpose(W),squared=True)
#Z_sk = -S*euclidean_distances(X=np.transpose(x),Y=W,squared=True)
print 'np'
print Z_np
print 'tf'
print Z_tf
print 'sk'
print Z_sk
打印声明的结果:
$ python ../tf_playground/pair_wise_distance.py
np
[[-0.17742723 -0.57233957 -0.54851648]
[-0.55742866 -0.76165072 -0.64532436]
[-0.65967836 -2.20533905 -0.5111396 ]
[-0.74866878 -0.47055669 -1.3993016 ]
[-1.03158802 -2.17094844 -0.2506851 ]]
tf
[[-0.17742723 -0.57233957 -0.54851648]
[-0.55742866 -0.76165072 -0.64532436]
[-0.65967836 -2.20533905 -0.5111396 ]
[-0.74866878 -0.47055669 -1.3993016 ]
[-1.03158802 -2.17094844 -0.2506851 ]]
sk
[[-0.15968451 -0.51510561 -0.49366484]
[-0.5016858 -0.68548565 -0.58079193]
[-0.59371052 -1.98480515 -0.46002564]
[-0.6738019 -0.42350102 -1.25937144]
[-0.92842922 -1.95385359 -0.22561659]]