I'm trying to understand sequence-to-sequence models and LSTMs by writing a simple text summarizer with Keras. I'm using a small subset of the CNN news dataset (500 articles). I'm not sure I'm doing this correctly.
First, after an hour of training, my loss is extremely high and my accuracy is low. Even though I'm not training the model on a graphics card, the accuracy still seems far too low.
Epoch 10/10
500/500 [==============================] - 59s 117ms/step - loss: 43815452.8000 - acc: 0.0231
Second, now that I've built the encoder, I'm working on the decoder, but I have no idea how to actually generate a summary. When I call the predict function, it returns an array of shape (500, 1748, 62), and the last dimension seems to contain probabilities. I thought these might be probabilities over my vocabulary, which would mean I should pick the most probable word, but my vocabulary has 313,974 words, so I don't know how to interpret this output. I've read some articles saying that decoding can be done with a second LSTM, but how do the two LSTMs communicate? Should I pass the states from the first one to the second?
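For reference, here is my current understanding of that pattern, written as a minimal sketch based on those articles (the names encoder_inputs, decoder_inputs, num_tokens and latent_dim are placeholders I made up; this is not my actual code):

from keras.models import Model
from keras.layers import Input, LSTM, Dense

latent_dim = 300    # placeholder hidden size
num_tokens = 10000  # placeholder vocabulary size

# Encoder: discard the per-timestep outputs, keep only the final states
encoder_inputs = Input(shape=(None, num_tokens))
encoder_outputs, state_h, state_c = LSTM(latent_dim, return_state=True)(encoder_inputs)
encoder_states = [state_h, state_c]

# Decoder: a second LSTM whose initial state is the encoder's final state;
# as far as I understand, this is how the two LSTMs "communicate"
decoder_inputs = Input(shape=(None, num_tokens))
decoder_outputs = LSTM(latent_dim, return_sequences=True)(decoder_inputs, initial_state=encoder_states)
decoder_outputs = Dense(num_tokens, activation='softmax')(decoder_outputs)

model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

Is this the right idea? At inference time, would I then take the argmax of the softmax output at each step and feed the predicted token back into the decoder?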
If anyone could tell me what I'm doing wrong and explain it, it would be greatly appreciated :).
Here is the code I use for the encoder:
#Some information:
#VOCAB_SIZE = 313974
#np.shape(x_train) = (500, 1748) and contains the indices of the words
#np.shape(word_embedding_matrix) = (313974, 300) and contains my words in a 300-dimensional space pre-processed with GloVe
import numpy as np
from keras.models import Sequential
from keras.layers import Embedding, LSTM, Dropout, Dense, Activation

def create_model():
    prob = 0.5
    EMBEDDING_DIM = 300
    VOCAB_SIZE = nb_words  # nb_words = 313974, set during preprocessing
    # Frozen embedding layer initialized with the GloVe matrix
    embedding_layer = Embedding(VOCAB_SIZE,
                                EMBEDDING_DIM,
                                weights=[word_embedding_matrix],
                                trainable=False)
    LAYER_NUM = 3
    HIDDEN_DIM = 300  # currently unused
    model = Sequential()
    model.add(embedding_layer)
    # Stacked LSTM + Dropout layers
    for i in range(LAYER_NUM - 1):
        model.add(LSTM(y_train.shape[1], return_sequences=True))
        model.add(Dropout(prob, name="drop_%s" % (i)))
    # Note: this layer is created but never added to the model
    LSTM(y_train.shape[1], return_sequences=False)
    model.add(Dense(y_train.shape[1]))
    model.add(Activation('softmax'))
    model.compile(loss="categorical_crossentropy",
                  optimizer="rmsprop",
                  metrics=['accuracy'])
    return model

model = create_model()
y_train = np.reshape(y_train, (y_train.shape[0], 1, y_train.shape[1]))
nb_epoch = 0
while True:
    print('\n')
    # Each call runs 10 epochs (nb_epoch is the old Keras argument name)
    model.fit(x_train, y_train, batch_size=BATCH_SIZE, verbose=1, nb_epoch=10)
    nb_epoch += 1
Model summary:
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
embedding_4 (Embedding)      (None, None, 300)         94192200
_________________________________________________________________
lstm_6 (LSTM)                (None, None, 62)          90024
_________________________________________________________________
drop_0 (Dropout)             (None, None, 62)          0
_________________________________________________________________
lstm_7 (LSTM)                (None, None, 62)          31000
_________________________________________________________________
drop_1 (Dropout)             (None, None, 62)          0
_________________________________________________________________
dense_2 (Dense)              (None, None, 62)          3906
_________________________________________________________________
activation_2 (Activation)    (None, None, 62)          0
=================================================================
Total params: 94,317,130
Trainable params: 124,930
Non-trainable params: 94,192,200
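For completeness, this is how I'm currently trying to read the output of predict, assuming the last axis is a probability distribution (which clearly can't be over my 313,974-word vocabulary, since it only has 62 entries):

preds = model.predict(x_train)         # shape (500, 1748, 62)
token_ids = np.argmax(preds, axis=-1)  # most probable index at each timestep
print(token_ids.shape)                 # (500, 1748)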