AttributeError: 'Embedding' object has no attribute 'get_shape' with the TensorFlow backend

Date: 2017-05-31 13:17:07

Tags: python tensorflow keras

I am trying to understand how the Embedding layer works together with masking (sequence-to-sequence regression).

This simple code fails with the error AttributeError: 'Embedding' object has no attribute 'get_shape'. The message seems accurate, but I don't know how to fix it. Any hints?

import numpy as np
from keras.layers import Input, Dense, LSTM
from keras.layers.embeddings import Embedding
from keras.layers.merge import Concatenate
from keras.models import Model
from keras.utils import plot_model

trainExs = np.asarray([ [1, 2, 3], [2, 3, 1]])
trainLabels = np.asarray([[1, 1, 1], [2, 2, 2]])

print('Examples, shape:', trainExs.shape)
print('Labels, shape:', trainLabels.shape)

W = [[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]
symDim = 3

# E M B E D D I N G S
# symbol_in = Input(shape=(None, 1), dtype='float32', name='symbol_input')
symbol_emb = Embedding(symDim+1, symDim,
                       weights=np.asarray(W), trainable=False, input_length=3)

symbol_dense = Dense(symDim, use_bias=True, name='symbol_dense')(symbol_emb)

output_layer = Dense(symDim, dtype='float32', name='output')(symbol_dense)

# M O D E L
model = Model(inputs=[symbol_emb], outputs=[output_layer])
model.compile(loss='mean_squared_error', optimizer='RMSprop', metrics=['accuracy'])
# print(model.summary())

The full stack trace is:

D:\python\python.exe D:/workspace/TESTS/test/testEMb.py
Using TensorFlow backend.
Examples, shape: (2, 3)
Labels, shape: (2, 3)
Traceback (most recent call last):
  File "D:/workspace/TESTS/test/testEMb.py", line 21, in <module>
    symbol_dense = Dense(symDim, use_bias=True, name='symbol_dense')(symbol_emb)
  File "D:\python\lib\site-packages\keras\engine\topology.py", line 541, in __call__
    self.assert_input_compatibility(inputs)
  File "D:\python\lib\site-packages\keras\engine\topology.py", line 450, in assert_input_compatibility
    ndim = K.ndim(x)
  File "D:\python\lib\site-packages\keras\backend\tensorflow_backend.py", line 479, in ndim
    dims = x.get_shape()._dims
AttributeError: 'Embedding' object has no attribute 'get_shape'
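The traceback shows the mechanism: Dense.__call__ runs assert_input_compatibility, which calls K.ndim(x), which in turn calls x.get_shape(). Only tensors have that method, so handing a layer object to another layer fails. The sketch below reproduces the same failure in plain Python. ToyTensor, ToyEmbedding and ToyDense are hypothetical, simplified stand-ins written for illustration; they are not Keras classes, and their shape rules only approximate what the real layers do.

```python
# Hypothetical stand-ins for illustration only; these are NOT Keras classes.
class ToyTensor:
    """Mimics a backend tensor: it knows its own shape."""
    def __init__(self, shape):
        self._shape = tuple(shape)

    def get_shape(self):
        return self._shape


class ToyEmbedding:
    """Mimics an embedding layer: appends an output_dim axis to the input shape."""
    def __init__(self, output_dim):
        self.output_dim = output_dim

    def __call__(self, x):
        return ToyTensor(x.get_shape() + (self.output_dim,))


class ToyDense:
    """Mimics a dense layer: replaces the last axis with `units`."""
    def __init__(self, units):
        self.units = units

    def __call__(self, x):
        # Like K.ndim(x) in the traceback, this assumes x has get_shape().
        ndim = len(x.get_shape())
        return ToyTensor(x.get_shape()[:ndim - 1] + (self.units,))


emb = ToyEmbedding(3)   # a layer object, like symbol_emb in the question
dense = ToyDense(3)

# Calling a layer on another layer fails the same way Keras does.
try:
    dense(emb)
    error_message = None
except AttributeError as err:
    error_message = str(err)
print(error_message)

# Calling the layers on a tensor (standing in for Input(shape=(3,))) works.
x = ToyTensor((None, 3))
out = dense(emb(x))
print(out.get_shape())  # (None, 3, 3)
```

The point of the sketch is the distinction the answer below relies on: a layer is a callable that produces tensors, and only the tensors it returns can be fed to the next layer or to Model(inputs=...).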

1 answer:

Answer (score: 2)

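You are passing symbol_emb to the model as its input, but symbol_emb is the Embedding layer object itself, not a tensor, so it is not a valid input. The same problem is what triggers the error in Dense(...)(symbol_emb): a layer must be called on a tensor, not on another layer. Define an Input and call the Embedding layer on it, for example: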

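# Sequences of 3 integer symbol indices; use shape=(None,) for variable-length sequences
sym_input = Input(shape=(3,), dtype='int32', name='symbol_input')
# Call the layer on the input tensor; weights must be a list of arrays
symbol_emb = Embedding(symDim+1, symDim,
                       weights=[np.asarray(W)], trainable=False)(sym_input)

...
...

model = Model(inputs=[sym_input], outputs=[output_layer])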

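Note that when the model is defined this way, you don't need to pass input_length to Embedding: the sequence length is already carried by the Input tensor's shape.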
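Since the question is ultimately about embeddings and masking, here is a NumPy sketch of what the frozen Embedding computes with the question's W, plus the mask Keras derives when mask_zero=True is passed. The question's code does not pass mask_zero, so that part is an assumption about where the code is heading, and the padded batch is a made-up example. With mask_zero=True, index 0 is reserved for padding, which fits the question's choice of symDim+1 rows with an all-zero row 0.

```python
import numpy as np

# Values from the question: W has symDim + 1 = 4 rows, one per index 0..3.
W = np.asarray([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype='float32')
trainExs = np.asarray([[1, 2, 3], [2, 3, 1]])

# A frozen Embedding is a row lookup: index i maps to W[i].
embedded = W[trainExs]  # shape (batch, timesteps, symDim) = (2, 3, 3)

# Assumption: mask_zero=True (not in the question's code). Keras then treats
# index 0 as padding and the mask is True wherever the index is non-zero.
padded = np.asarray([[1, 2, 0], [3, 0, 0]])  # hypothetical padded batch
mask = padded != 0

print(embedded.shape)
print(mask)
```

Because row 0 of W is all zeros and the other rows are one-hot, the lookup turns each symbol index into a one-hot vector, and padded positions are both zero-valued and masked out for downstream layers that support masking (such as LSTM).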