Keras Reshape:新数组的总大小必须保持不变

时间:2018-08-10 23:18:25

标签: tensorflow keras deep-learning lstm reshape

我正在尝试使用Keras Reshape函数API将手套嵌入(4D形状:( ?、 9、20、100))的输出重塑为3D(?,9、2000)。但是,当我尝试Reshape((9,2000))(text_layer)时,会弹出一个错误,说新数组的总大小必须保持不变,即使9 * 20 * 100 = 9 *2000。为什么?附带代码。

text = Input(shape=(9, news_text.shape[1]), name='text')
text_layer = Embedding(
    embedding_matrix.shape[0],
    embedding_matrix.shape[1],
    weights=[embedding_matrix],
    input_length=news_text.shape[1]
)(text)
text_layer = Reshape((9, text_layer.shape[2] * text_layer.shape[3]))(text_layer)

1 个答案:

答案 0 :(得分:2)

input_length层中删除 Embedding 参数。

这很奇怪,我不知道原因,但是当您指定参数input_length时,就会引发错误。

无论如何,Embedding层接收Input层的尺寸。看来参数input_length具有非常特殊的用途,以便在使用Flatten层等后知道张量的大小。

在这种情况下,Embedding层会从输入张量中获取输出张量的形状,而忽略input_length参数。

(如果设置了无效值,则在添加下一层之前不会引发错误。请注意,input_lenght和结果shape):

>>> inp = Input(shape=(9,20))
>>> emb = Embedding(100,100, input_length=84) (inp)
>>> emb
<tf.Tensor 'embedding_5/embedding_lookup:0' shape=(?, 9, 20, 100) dtype=float32>
>>> res = Reshape((9,2000)) (emb)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  ...

但是,当您添加input_length层时,似乎Reshape参数有冲突。

最后:

text = Input(shape=(9, news_text.shape[1]), name='text')
text_layer = Embedding(
    embedding_matrix.shape[0],
    embedding_matrix.shape[1],
    weights=[embedding_matrix],
)(text)
text_layer = Reshape((9, text_layer.shape[2] * text_layer.shape[3]))(text_layer)