密集输出层的TensorFlow输入形状错误与model.summary()表示的矛盾

时间:2020-08-22 15:48:04

标签: python tensorflow tf.keras huggingface-transformers

我正在研究NLP问题(句子分类),因此决定将Hu​​ggingFace的TFBertModel与Conv1D,Flatten和Dense层一起使用。我正在使用功能性API,并且我的模型可以编译。但是,在model.fit()期间,在输出Dense层出现形状错误。

模型定义:

# Build model with a max length of 50 words in a sentence
max_len = 50
def build_model():
    bert_encoder = TFBertModel.from_pretrained(model_name)
    input_word_ids = tf.keras.Input(shape=(max_len,), dtype=tf.int32, name="input_word_ids")
    input_mask = tf.keras.Input(shape=(max_len,), dtype=tf.int32, name="input_mask")
    input_type_ids = tf.keras.Input(shape=(max_len,), dtype=tf.int32, name="input_type_ids")
    
    # Create a conv1d model. The model may not really be useful or make sense, but that's OK (for now).
    embedding = bert_encoder([input_word_ids, input_mask, input_type_ids])[0]
    conv_layer = tf.keras.layers.Conv1D(32, 3, activation='relu')(embedding)
    dense_layer = tf.keras.layers.Dense(24, activation='relu')(conv_layer)
    flatten_layer = tf.keras.layers.Flatten()(dense_layer)
    output_layer = tf.keras.layers.Dense(3, activation='softmax')(flatten_layer)
    
    model = tf.keras.Model(inputs=[input_word_ids, input_mask, input_type_ids], outputs=output_layer)
    model.compile(tf.keras.optimizers.Adam(lr=1e-5), loss='sparse_categorical_crossentropy', 
                  metrics=['accuracy'])
    return model

# View model architecture
model = build_model()
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_word_ids (InputLayer)     [(None, 50)]         0                                            
__________________________________________________________________________________________________
input_mask (InputLayer)         [(None, 50)]         0                                            
__________________________________________________________________________________________________
input_type_ids (InputLayer)     [(None, 50)]         0                                            
__________________________________________________________________________________________________
tf_bert_model (TFBertModel)     ((None, 50, 768), (N 177853440   input_word_ids[0][0]             
                                                                 input_mask[0][0]                 
                                                                 input_type_ids[0][0]             
__________________________________________________________________________________________________
conv1d (Conv1D)                 (None, 48, 32)       73760       tf_bert_model[0][0]              
__________________________________________________________________________________________________
dense (Dense)                   (None, 48, 24)       792         conv1d[0][0]                     
__________________________________________________________________________________________________
flatten (Flatten)               (None, 1152)         0           dense[0][0]                      
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 3)            3459        flatten[0][0]                    
==================================================================================================
Total params: 177,931,451
Trainable params: 177,931,451
Non-trainable params: 0
__________________________________________________________________________________________________

# Fit model on input data
model.fit(train_input, train['label'].values, epochs = 3, verbose = 1, batch_size = 16, 
          validation_split = 0.2)

这是错误消息:

ValueError:层密实_1的输入0与该层不兼容:输入形状的预期轴-1具有值1152但已接收 输入形状为[16,6168]

我无法理解如何向layer density_1(输出密集层)输入形状为6168?根据模型摘要,它应该始终为1152。

1 个答案:

答案 0 :(得分:0)

您输入的形状可能不符合您的预期。检查train_input的形状。