I'm trying to generate some text from a list of company names that I pass in as training data. The tutorial I'm following is https://www.tensorflow.org/tutorials/text/text_generation#process_the_text.
When I try to fit the model, I get:
ValueError: Input 0 is incompatible with layer gru_2: expected shape=(1, None, 256), found shape=[64, 14, 256]
My dataset has these dimensions: <BatchDataset shapes: ((64, 14), (64, 14)), types: (tf.int32, tf.int32)>
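The check behind this error can be sketched in plain Python. This is a simplified re-implementation for illustration, not Keras's actual code, and the function name `shapes_compatible` is hypothetical. It assumes the rule is: the ranks must match, and every dimension fixed in the expected shape (anything other than `None`) must equal the incoming dimension. A layer built with a batch dimension of 1 therefore rejects a batch of 64, while a `None` dimension (here the sequence length) accepts anything.

```python
from typing import Optional, Sequence


def shapes_compatible(expected: Sequence[Optional[int]], found: Sequence[int]) -> bool:
    """Hypothetical, simplified shape check: ranks must match and every
    non-None dimension in `expected` must equal the one in `found`."""
    if len(expected) != len(found):
        return False
    return all(e is None or e == f for e, f in zip(expected, found))


# Embedding output for one dataset batch: (batch, sequence length, embedding_dim)
found = (64, 14, 256)

# Batch dimension pinned to 1, as in the error message
print(shapes_compatible((1, None, 256), found))
# Batch dimension pinned to 64, matching BATCH_SIZE
print(shapes_compatible((64, None, 256), found))
```

The first call prints `False` and the second `True`; only the fixed batch dimension differs between them.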
I create the model with the following code:
import tensorflow as tf

BATCH_SIZE = 64
vocab_size = 62
embedding_dim = 256
rnn_units = 2048

def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
    model = tf.keras.Sequential([
        # batch_input_shape fixes the batch dimension, which a stateful GRU requires
        tf.keras.layers.Embedding(vocab_size, embedding_dim, batch_input_shape=[batch_size, None]),
        # tf.keras.layers.Bidirectional(tf.keras.layers.GRU(rnn_units, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform')),
        tf.keras.layers.GRU(rnn_units, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform'),
        tf.keras.layers.Dense(vocab_size)
    ])
    return model
# model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=BATCH_SIZE)
My model summary looks like this:
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
embedding_2 (Embedding)      (1, None, 256)            15872
_________________________________________________________________
gru_2 (GRU)                  (1, None, 1024)           3938304
_________________________________________________________________
dense_2 (Dense)              (1, None, 62)             63550
=================================================================
Total params: 4,017,726
Trainable params: 4,017,726
Non-trainable params: 0
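The parameter counts in this summary can be checked by hand. The sketch below assumes TF 2.x defaults: the GRU uses `reset_after=True`, so each of its three gates has an input kernel, a recurrent kernel, and two bias vectors. The helper names are hypothetical. Under those assumptions, the numbers reproduce exactly with 1024 GRU units and a vocabulary of 62. With `rnn_units = 2048`, as set in the code above, the total would be much larger, so the summarized model appears to have been built with different arguments.

```python
def embedding_params(vocab_size: int, embedding_dim: int) -> int:
    # One embedding vector per vocabulary entry
    return vocab_size * embedding_dim


def gru_params(input_dim: int, units: int, reset_after: bool = True) -> int:
    # 3 gates (update, reset, candidate), each with an input kernel,
    # a recurrent kernel and bias; reset_after=True keeps two biases per gate
    n_bias = 2 if reset_after else 1
    return 3 * (units * (input_dim + units) + n_bias * units)


def dense_params(input_dim: int, units: int) -> int:
    # Weight matrix plus one bias per output unit
    return input_dim * units + units


def total_params(vocab_size: int, embedding_dim: int, rnn_units: int) -> int:
    return (embedding_params(vocab_size, embedding_dim)
            + gru_params(embedding_dim, rnn_units)
            + dense_params(rnn_units, vocab_size))


print(total_params(62, 256, 1024))  # 4017726, the total in the summary
print(total_params(62, 256, 2048))  # 14310974
```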