While trying to predict the topic of news articles, I ran into several problems. The news articles are already cleaned (no symbols, numbers, ...). There are 6 categories, and the dataset has 13,000 news articles per category (the dataset is evenly distributed).
Preprocessing:
import re
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

stop_words = set(stopwords.words('english'))
for index, row in data.iterrows():
    print("Index: ", index)
    # keep letters and spaces only, collapse whitespace, lowercase
    txt_clean = ' '.join(re.sub("([^a-zA-Z ])", " ", data.loc[index, 'txt_clean']).split()).lower()
    word_tokens = word_tokenize(txt_clean)
    # drop English stop words
    filtered_sentence = [w for w in word_tokens if w not in stop_words]
    cleaned_text = ''
    for w in filtered_sentence:
        cleaned_text = cleaned_text + ' ' + w
    data.loc[index, 'txt_clean'] = cleaned_text
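(The construction of X_train/Y_train/X_test/Y_test is not shown in the question. A minimal sketch of one way they could be built, assuming a Keras Tokenizer with num_words=50000 and maxlen=500 to match the model and prediction code, and a hypothetical label column named 'category':)

import pandas as pd
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split

# hypothetical training-data preparation: fit the tokenizer on the training texts
tokenizer = Tokenizer(num_words=50000)
tokenizer.fit_on_texts(data['txt_clean'].values)

X = pad_sequences(tokenizer.texts_to_sequences(data['txt_clean'].values), maxlen=500)
Y = pd.get_dummies(data['category']).values  # 'category' is an assumed label column

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.1)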
I implemented an RNN with an LSTM as follows:
from keras.models import Sequential
from keras.layers import Embedding, SpatialDropout1D, LSTM, Dense

model = Sequential()
model.add(Embedding(50000, 100, input_length=500))
model.add(SpatialDropout1D(0.2))
model.add(LSTM(150, dropout=0.2, recurrent_dropout=0.2))
model.add(Dense(6, activation='softmax'))
model.summary()
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

history = model.fit(X_train, Y_train, epochs=epochs, batch_size=batch_size, validation_split=0.1)

accr = model.evaluate(X_test, Y_test)
print('Test set\n Loss: {:0.3f}\n Accuracy: {:0.3f}'.format(accr[0], accr[1]))
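(The prediction code below loads model.h5; the corresponding save call is not shown, but it was presumably something like the following sketch:)

model.save('model.h5')  # assumed: matches load_model('model.h5') used below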
Prediction:
import numpy as np
from keras.models import load_model
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences

model = load_model('model.h5')
data = data.sample(n=15000)
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])

tokenizer = Tokenizer(num_words=50000)
# note: fitted on the prediction sample, not on the same texts used for training
tokenizer.fit_on_texts(data['txt_clean'].values)

CATEGORIES = ['A', 'B', 'C', 'D', 'E', 'F']
for index, row in data.iterrows():
    seq = tokenizer.texts_to_sequences([data.loc[index, 'txt_clean']])
    padded = pad_sequences(seq, maxlen=500)
    pred = model.predict(padded)
    pred = pred[0]
    print(pred, pred[np.argmax(pred)])
For example, after 10 epochs with a batch_size of 500:
I also tried reducing the batch_size to 64:
To me the results with a batch size of 64 look better, but when I predict the news articles (one by one) I get an accuracy of 15.97%. This prediction accuracy is much lower than the training and test accuracy.
What could be the problem?
Thanks!
Answer 0 (score: 0)
This is a classic problem in ML or DL. There could be two reasons.
Answer 1 (score: 0)
Try using pickle or joblib to pickle your Keras tokenizer: save the tokenizer fitted during training and reuse the same one for prediction.
Here is sample code for saving the Keras tokenizer:
import pickle

# saving
with open('tokenizer.pickle', 'wb') as handle:
    pickle.dump(tokenizer, handle, protocol=pickle.HIGHEST_PROTOCOL)

# loading
with open('tokenizer.pickle', 'rb') as handle:
    tokenizer = pickle.load(handle)
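A possible usage sketch at prediction time (assuming tokenizer.pickle was written right after fit_on_texts on the training texts): load the saved tokenizer and reuse it instead of fitting a new one on the prediction sample.

import pickle
import numpy as np
from keras.models import load_model
from keras.preprocessing.sequence import pad_sequences

model = load_model('model.h5')
with open('tokenizer.pickle', 'rb') as handle:
    tokenizer = pickle.load(handle)  # tokenizer fitted on the training texts

# reuse the training tokenizer; do NOT call fit_on_texts on the prediction data
seq = tokenizer.texts_to_sequences(data['txt_clean'].values)
padded = pad_sequences(seq, maxlen=500)
preds = model.predict(padded)

CATEGORIES = ['A', 'B', 'C', 'D', 'E', 'F']
predicted_labels = [CATEGORIES[np.argmax(p)] for p in preds]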