我正在尝试训练Keras测试摘要模型,以生成新闻文章的新标题,然后将其与已发布的标题进行比较。我正在使用GloVe 6B进行培训,然后根据该文章进行预测,该文章已通过删除标签,解析,删除停用词,词形修饰然后重新加入而进行了清理。我的结果看起来像这样:
最初的头条新闻:福特汽车尽管据称担心飞行,但仍于八月份前往马里兰 生成的标题:意见:要访问的
清洗后的文章文字:布雷特·卡瓦诺(Brett Kavanaugh)原告克里斯蒂娜·布拉西·福特(Christine Blasey Ford)尽管害怕飞行,但还是接受了马里兰州的测谎仪测试。福特公司于8月7日对测谎仪进行了测验。福特·福特8月7日,马里兰州巴尔的摩的希尔顿酒店巴尔的摩华盛顿国际机场福特享受飞行艰难时光的逃生路线克里斯蒂娜•布拉西•福特教授指控最高法院提名人布雷特•卡瓦诺曾性侵犯高中曾告诉朋友所谓的遭遇30年前持久影响生活两位长期朋友福特告诉CNN周前曾表示感到不自在挣扎封闭的空间逃生路线出口门暗示称不适遇到了所谓的卡瓦诺(Kavanaugh)的原因福特享受飞行DeVarney说飞机最终封闭了空间恐惧飞行的福特能够及时地作证参议院司法机构在2018年7月30日签署的加利福尼亚民主党参议员戴安·费恩斯坦福特表示,福特将于8月7日在大西洋中部度假,并于星期四美国东部时间上午10点作证。
这是我的训练代码:
from __future__ import print_function
import pandas as pd
from sklearn.model_selection import train_test_split
from keras_text_summarization.library.utility.plot_utils import plot_and_save_history
from keras_text_summarization.library.seq2seq import Seq2SeqGloVeSummarizer
from keras_text_summarization.library.applications.fake_news_loader import fit_text
import numpy as np
LOAD_EXISTING_WEIGHTS = False
def main():
np.random.seed(42)
data_dir_path = './data'
very_large_data_dir_path = './very_large_data'
report_dir_path = './reports'
model_dir_path = './models'
print('loading csv file ...')
df = pd.read_csv("dcr Man_Cleaned.csv")
print('extract configuration from input texts ...')
Y = df.Title
X = df['Joined']
config = fit_text(X, Y)
print('configuration extracted from input texts ...')
summarizer = Seq2SeqGloVeSummarizer(config)
summarizer.load_glove(very_large_data_dir_path)
if LOAD_EXISTING_WEIGHTS:
summarizer.load_weights(weight_file_path=Seq2SeqGloVeSummarizer.get_weight_file_path(model_dir_path=model_dir_path))
Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, Y, test_size=0.2, random_state=42)
print('training size: ', len(Xtrain))
print('testing size: ', len(Xtest))
print('start fitting ...')
history = summarizer.fit(Xtrain, Ytrain, Xtest, Ytest, epochs=20, batch_size=16)
history_plot_file_path = report_dir_path + '/' + Seq2SeqGloVeSummarizer.model_name + '-history.png'
if LOAD_EXISTING_WEIGHTS:
history_plot_file_path = report_dir_path + '/' + Seq2SeqGloVeSummarizer.model_name + '-history-v' + str(summarizer.version) + '.png'
plot_and_save_history(history, summarizer.model_name, history_plot_file_path, metrics={'loss', 'acc'})
if __name__ == '__main__':
main()
任何对这里出问题的想法都表示赞赏。
答案 0 :(得分:0)
好,所以这是一个非常广泛的问题,很多事情都可能出错。这里的问题是由于多种原因,您的模型卡在了类中:
无论哪种方式,您都应该尝试添加一些层,删除一些层,添加一些正则化,然后尝试轮盘赌法根据单词概率生成下一个单词。希望这会有所帮助:)