如何连接预训练的嵌入层和输入层

时间:2020-02-13 13:49:25

标签: python tensorflow keras concatenation word-embedding

normal_input = Input(shape=(56,))

pretrained_embeddings = Embedding(num_words, 200, input_length=max_length, trainable=False,
                                                            weights=[ft_embedding_matrix])

concatenated = concatenate([normal_input, pretrained_embeddings])

dense = Dense(256, activation='relu')(concatenated)

我的想法是创建一个256维的输入并将其传递到一个密集层。

我遇到以下错误。

ValueError :使用非符号张量的输入调用了concatenate_10层。收到的类型:。全输入:[,]。该层的所有输入都应为张量。

请帮助我该怎么做。

2 个答案:

答案 0 :(得分:1)

您需要输入以选择要使用的嵌入。

由于您使用了150个单词,所以嵌入的形状为(batch,150,200),无论如何都无法与(batch, 56)串联。您需要以某种方式进行变形以匹配形状。我建议您尝试使用Dense层将56转换为200 ...

word_input = Input((150,))
normal_input = Input((56,))

embedding = pretrained_embeddings(word_input)
normal = Dense(200)(normal_input)

#you could add some normalization here - read below

normal = Reshape((1,200))(normal)
concatenated = Concatenate(axis=1)([normal, embedding]) 

我还建议,由于嵌入和您的输入来自不同的性质,因此您应应用规范化,以使它们变得更相似:

embedding = BatchNormalization(center=False, scale=False)(embedding)
normal = BatchNormalization(center=False, scale=False)(normal)

另一种可能性(我不能说是最好的)是在另一个维度上串联,将56转换为150:

word_input = Input((150,))
normal_input = Input((56,))

embedding = pretrained_embeddings(word_input)
normal = Dense(150)(normal_input)

#you could add some normalization here - read below

normal = Reshape((150,1))(normal)
concatenated = Concatenate(axis=-1)([normal, embedding]) 

我认为这更适合循环和卷积网络,您添加一个新渠道而不是添加一个新步骤。


您甚至可以尝试双重串联,听起来很酷:D

word_input = Input((150,))
normal_input = Input((56,))

embedding = pretrained_embeddings(word_input)
normal150 = Dense(150)(normal_input)
normal201 = Dense(201)(normal_input)

embedding = BatchNormalization(center=False, scale=False)(embedding)
normal150 = BatchNormalization(center=False, scale=False)(normal150)
normal201 = BatchNormalization(center=False, scale=False)(normal201)


normal150 = Reshape((150,1))(normal150)
normal201 = Reshape((1,201))(normal201)
concatenated = Concatenate(axis=-1)([normal150, embedding]) 
concatenated = Concatenate(axis= 1)([normal201, concatenated])

答案 1 :(得分:0)

那是因为连接层被这样称呼:

concatenated = Concatenate()([normal_input, pretrained_embeddings])