我有一个jdes
的张量(?, 100)
和一个形状为jt_six
的常数矩阵(6,100)
。而且,我试图得到jdes
和jt_six
的每一行的余弦接近度的结果,并且结果的形状应为(?, 6)
。我看到dot()
层能够计算余弦接近度设置normalize=True
,但是有了我的代码,我得到了形状为(6,1)
的结果,其中没有批量大小。有人可以帮我吗?
def dot_similarity(jdes):
jdes = K.l2_normalize(jdes, axis=-1) # (?, 100)
jt_six = K.l2_normalize(K.variable(jt_six), axis=-1) # (6, 100)
return dot([jt_six, jdes], axes=-1, normalize=True) # (6, 1), need (?, 6)
result = Lambda(dot_similarity)(jdes)
答案 0 :(得分:0)
您可以直接使用K.dot()
。因为您已经使用K.l2_normalize
,所以矩阵乘法的结果就是余弦接近度。
from keras.models import Model
import keras.backend as K
from keras.layers import Lambda,Input
import numpy as np
N = 100
def dot_similarity(jdes):
jdes = K.l2_normalize(jdes, axis=-1) # (?, 100)
# define it myself
jt_six = K.constant(np.random.uniform(0, 1, size=(6, N)))
jt_six = K.l2_normalize(K.variable(jt_six), axis=-1) # (6, 100)
return K.dot(jdes,K.transpose(jt_six))
jdes = Input(shape=(N,))
result = Lambda(dot_similarity)(jdes)
model = Model(jdes,result)
print(model.summary())
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) (None, 100) 0
_________________________________________________________________
lambda_1 (Lambda) (None, 6) 0
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________