我试图在Keras中编写一个加权嵌入层,该层采用权重矩阵表示输入中的每个时间步,并输出其中的所有嵌入向量的加权和。
我正在使用Keras GitHub存储库中的嵌入层实现,因此我想更改调用函数的实现,并将gather
行替换为另一个矩阵操作操作。
def call(self, inputs):
if K.dtype(inputs) != 'int32':
inputs = K.cast(inputs, 'int32')
out = K.gather(self.embeddings, inputs) # Change this
return out
有什么想法吗?