在Keras中,如何使用dot()计算张量与常数矩阵的每一行之间的余弦接近度?

时间:2019-01-16 04:52:46

标签: python keras

我有一个jdes的张量(?, 100)和一个形状为jt_six的常数矩阵(6,100)。而且,我试图得到jdesjt_six的每一行的余弦接近度的结果,并且结果的形状应为(?, 6)。我看到dot()层能够计算余弦接近度设置normalize=True,但是有了我的代码,我得到了形状为(6,1)的结果,其中没有批量大小。有人可以帮我吗?

def dot_similarity(jdes):
    jdes = K.l2_normalize(jdes, axis=-1) # (?, 100)
    jt_six = K.l2_normalize(K.variable(jt_six), axis=-1) # (6, 100)
    return dot([jt_six, jdes], axes=-1, normalize=True) # (6, 1), need (?, 6)

result = Lambda(dot_similarity)(jdes)

1 个答案:

答案 0 :(得分:0)

您可以直接使用K.dot()。因为您已经使用K.l2_normalize,所以矩阵乘法的结果就是余弦接近度。

from keras.models import Model
import keras.backend as K
from keras.layers import Lambda,Input
import numpy as np

N = 100
def dot_similarity(jdes):
    jdes = K.l2_normalize(jdes, axis=-1) # (?, 100)
    # define it myself
    jt_six = K.constant(np.random.uniform(0, 1, size=(6, N)))
    jt_six = K.l2_normalize(K.variable(jt_six), axis=-1) # (6, 100)
    return K.dot(jdes,K.transpose(jt_six))

jdes = Input(shape=(N,))
result = Lambda(dot_similarity)(jdes)
model = Model(jdes,result)
print(model.summary())

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 100)               0         
_________________________________________________________________
lambda_1 (Lambda)            (None, 6)                 0         
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________