如何在TensorFlow中实现成对距离?

时间:2016-06-28 18:06:22

标签: python numpy tensorflow

我在张量流中实现了平方成对的eucledian距离,但是无法将其与sklearn库匹配。我使用以下方法实现了它:

  

dist(x,y)= dot(x,x) - 2 * dot(x,y)+ dot(y,y)

但无法获得匹配的功能。有谁知道什么可能是错的?这是脚本:

import numpy as np
import tensorflow as tf
import sklearn as sk
from sklearn.metrics.pairwise import euclidean_distances

def get_Z_tf(x,W,S,l='layer'):
    W = tf.Variable(W, name='W'+l, trainable=True, dtype=tf.float64)
    S = tf.Variable(S, name='S'+l, trainable=True, dtype=tf.float64)
    WW =  tf.reduce_sum(W*W, reduction_indices=0, keep_dims=True) #( 1 x D^(l)= sum( (D^(l-1) x D^(l)), 0 )
    XX =  tf.reduce_sum(x*x, reduction_indices=1, keep_dims=True) # (M x 1) = sum( (M x D^(l-1)), 1 )
    # -|| x - w ||^2 = -(-2<x,w> + ||x||^2 + ||w||^2) = 2<x,w> - (||x||^2 + ||w||^2)
    Delta_tilde = 2.0*tf.matmul(x,W) - tf.add(WW, XX)
    return Delta_tilde

def get_z_np(x,W,S):
    WW = np.sum(np.multiply(W,W), axis=0, dtype=None, keepdims=True)
    XX = np.sum(np.multiply(x,x), axis=1, dtype=None, keepdims=True)
    Delta_tilde = 2.0*np.dot(x,W) - (WW + XX)
    return Delta_tilde

W = np.random.rand(4,3)
x = np.random.rand(5,4)
S = 0.9

#x_tf = tf.constant(x)
#sklearn.metrics.pairwise.euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False, X_norm_squared=None)[source]

with tf.Session() as sess:
    x_const = tf.constant(x)
    Z_tf = get_Z_tf(x_const,W,S)
    sess.run(tf.initialize_all_variables())
    Z_tf = sess.run(Z_tf)
    Z_np = get_z_np(x,W,S)
    #Z_sk = S*sklearn.metrics.pairwise.euclidean_distances(X=X,Y=W,squared=True)
    Z_sk = -S*euclidean_distances(X=x,Y=np.transpose(W),squared=True)
    #Z_sk = -S*euclidean_distances(X=np.transpose(x),Y=W,squared=True)

    print 'np'
    print Z_np
    print 'tf'
    print Z_tf
    print 'sk'
    print Z_sk

打印声明的结果:

$ python ../tf_playground/pair_wise_distance.py
np
[[-0.17742723 -0.57233957 -0.54851648]
 [-0.55742866 -0.76165072 -0.64532436]
 [-0.65967836 -2.20533905 -0.5111396 ]
 [-0.74866878 -0.47055669 -1.3993016 ]
 [-1.03158802 -2.17094844 -0.2506851 ]]
tf
[[-0.17742723 -0.57233957 -0.54851648]
 [-0.55742866 -0.76165072 -0.64532436]
 [-0.65967836 -2.20533905 -0.5111396 ]
 [-0.74866878 -0.47055669 -1.3993016 ]
 [-1.03158802 -2.17094844 -0.2506851 ]]
sk
[[-0.15968451 -0.51510561 -0.49366484]
 [-0.5016858  -0.68548565 -0.58079193]
 [-0.59371052 -1.98480515 -0.46002564]
 [-0.6738019  -0.42350102 -1.25937144]
 [-0.92842922 -1.95385359 -0.22561659]]

0 个答案:

没有答案