Keras:密集与嵌入 - ValueError:输入0与层repeat_vector_9不兼容:预期ndim = 2,发现ndim = 3

时间:2017-11-18 07:32:01

标签: keras embedding keras-layer word-embedding keras-2

我有以下网络可以正常工作:

# Generate by a command
send [format "%c" 4]

但是,我需要用嵌入层替换Dense图层:

left = Sequential()
left.add(Dense(EMBED_DIM,input_shape=(ENCODE_DIM,)))
left.add(RepeatVector(look_back))

然后当我使用嵌入层时出现以下错误:

left = Sequential()
left.add(Embedding(ENCODE_DIM, EMBED_DIM, input_length=1))
left.add(RepeatVector(look_back))

使用嵌入层替换Dense图层时需要进行哪些其他更改?谢谢!

1 个答案:

答案 0 :(得分:3)

Dense图层的输出形状为(None, EMBED_DIM)。但是,Embedding图层的输出形状为(None, input_length, EMBED_DIM)。使用input_length=1,它将是(None, 1, EMBED_DIM)。您可以在Flatten图层后添加Embedding图层以移除轴1。

您可以打印输出形状以调试模型。例如,

EMBED_DIM = 128
left = Sequential()
left.add(Dense(EMBED_DIM, input_shape=(ENCODE_DIM,)))
print(left.output_shape)
(None, 128)

left = Sequential()
left.add(Embedding(ENCODE_DIM, EMBED_DIM, input_length=1))
print(left.output_shape)
(None, 1, 128)

left.add(Flatten())
print(left.output_shape)
(None, 128)