用于文本分类的Tenserflow模型无法按预期进行预测?

时间:2020-02-01 10:01:07

标签: python tensorflow machine-learning keras

我正在尝试训练一个用于情感分析的模型,将数据分为训练和测试时,它显示90%的准确性!但是,每当我在一个新短语上对其进行测试时,结果几乎都相同(通常在0.86-0.95范围内)! 这是代码:

sentences = data['text'].values.astype('U')
y = data['label'].values

sentences_train, sentences_test, y_train, y_test = train_test_split(sentences, y, test_size=0.2, random_state=1000)

tokenizer = Tokenizer(num_words=5000)
tokenizer.fit_on_texts(sentences_train)

X_train = tokenizer.texts_to_sequences(sentences_train)
X_test = tokenizer.texts_to_sequences(sentences_test)

vocab_size = len(tokenizer.word_index) + 1  

maxlen = 100

X_train = pad_sequences(X_train, padding='post', maxlen=maxlen)
X_test = pad_sequences(X_test, padding='post', maxlen=maxlen)

embedding_dim = 50
model = Sequential()
model.add(Embedding(input_dim=vocab_size, output_dim=embedding_dim,input_length=maxlen))
model.add(layers.Flatten())
model.add(layers.Dense(10, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])
model.summary()

history = model.fit(X_train, y_train,
                    epochs=5,
                    verbose=True,
                    validation_data=(X_test, y_test),
                    batch_size=10)
loss, accuracy = model.evaluate(X_train, y_train, verbose=False)
print("Training Accuracy: {:.4f}".format(accuracy))
loss, accuracy = model.evaluate(X_test, y_test, verbose=False)
print("Testing Accuracy:  {:.4f}".format(accuracy))

训练数据是一个包含3个列的CSV文件:(id,文本,标签(0,1)),其中0为正数,1为负数。

Training Accuracy: 0.9855
Testing Accuracy:  0.9013

在新句子(例如“这只是文本!”)上进行测试。和“讨厌的传教士!”会预测相同的结果[0.85],[0.83]。

1 个答案:

答案 0 :(得分:1)

您似乎是<script id="source" type="worker"> importScripts("https://greggman.github.io/doodles/test/ping-worker.js"); </script>的受害者。换句话说,我们的模型将过度适合overfitting。尽管通常可以在训练集上达到高精度,但是像您的情况一样,我们真正想要的是开发能够很好地归纳为测试数据(或没见过)。

您可以按照these的步骤进行操作,以防止过度拟合。

此外,为了提高算法性能,建议您增加training data层的 neurons 数量,并设置更多Dense以提高的性能。用于测试新数据的算法。