我正在尝试使用Keras Reshape函数API将手套嵌入(4D形状:( ?、 9、20、100))的输出重塑为3D(?,9、2000)。但是,当我尝试Reshape((9,2000))(text_layer)时,会弹出一个错误,说新数组的总大小必须保持不变,即使9 * 20 * 100 = 9 *2000。为什么?附带代码。
text = Input(shape=(9, news_text.shape[1]), name='text')
text_layer = Embedding(
embedding_matrix.shape[0],
embedding_matrix.shape[1],
weights=[embedding_matrix],
input_length=news_text.shape[1]
)(text)
text_layer = Reshape((9, text_layer.shape[2] * text_layer.shape[3]))(text_layer)
答案 0 :(得分:2)
从input_length
层中删除 Embedding
参数。
这很奇怪,我不知道原因,但是当您指定参数input_length
时,就会引发错误。
无论如何,Embedding
层接收Input
层的尺寸。看来参数input_length
具有非常特殊的用途,以便在使用Flatten
层等后知道张量的大小。
在这种情况下,Embedding
层会从输入张量中获取输出张量的形状,而忽略input_length
参数。
(如果设置了无效值,则在添加下一层之前不会引发错误。请注意,input_lenght
和结果shape
):
>>> inp = Input(shape=(9,20))
>>> emb = Embedding(100,100, input_length=84) (inp)
>>> emb
<tf.Tensor 'embedding_5/embedding_lookup:0' shape=(?, 9, 20, 100) dtype=float32>
>>> res = Reshape((9,2000)) (emb)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
...
但是,当您添加input_length
层时,似乎Reshape
参数有冲突。
最后:
text = Input(shape=(9, news_text.shape[1]), name='text')
text_layer = Embedding(
embedding_matrix.shape[0],
embedding_matrix.shape[1],
weights=[embedding_matrix],
)(text)
text_layer = Reshape((9, text_layer.shape[2] * text_layer.shape[3]))(text_layer)