在python中使用word_tokenize时出现keyError

时间:2019-08-09 02:42:27

标签: python machine-learning keras

我正在尝试使用keras和IMDB数据集运行情感分析问题,但是当我尝试对文本进行标记时,会遇到关键错误

import numpy as np
from keras.datasets import imdb
import json
(X_train, y_train), (X_test, y_test) = imdb.load_data(num_words=5000)
from keras.preprocessing import sequence
X_train = sequence.pad_sequences(X_train, maxlen = 500)
X_test = sequence.pad_sequences(X_test, maxlen = 500)

from keras import Sequential
from keras.layers import Embedding, LSTM, Dense, Dropout

model = Sequential()
model.add(Embedding(5000, 32, input_length = 500))
model.add(LSTM(units = 100))
model.add(Dense(1, activation = 'sigmoid'))
print(model.summary())
model.compile(loss = 'binary_crossentropy', optimizer = 'adam', metrics=['accuracy'])
model.fit(X_train, y_train, batch_size = 64, epochs = 6)

good = "A great movie"
bad = "This was not a great movie"

from nltk import word_tokenize
from keras.preprocessing import sequence

word2index = imdb.get_word_index()
X=[]
for word in word_tokenize(good):
     X.append(word2index[word])
X=sequence.pad_sequences([X],maxlen=500)

loaded_model.predict(X)

由于我遇到错误,它没有进入预测部分

 File "<ipython-input-51-9268dcdfa83f>", line 9, in <module>
    test.append(word2index[word])

KeyError: 'A'

我该怎么解决这个问题?

2 个答案:

答案 0 :(得分:0)

优良作法是避免使用大写单词,并在字符串上使用.lower()

您的字符Aword2index字典中不存在,但a存在。您会注意到word2index中的每个元素都是小写字母。

因此,如果您执行X.append(word2index[word.lower()]),则应该获得适当的结果。

答案 1 :(得分:0)

首先,仅使用以下代码即可重现错误:

from keras.datasets import imdb
from nltk import word_tokenize

good = "A great movie"

word2index = imdb.get_word_index()
X=[]
for word in word_tokenize(good):
     X.append(word2index[word])

print(X)

问题在于word字典中没有A = word2index。如Tensorflow/Keras documentation中所述,方法get_word_index() returns a dictionary是从 imdb_word_index.json 转换的。

您可以检查该词典的内容。如API文档中所述,imdb数据集已本地下载到〜/ .keras / datasets / imdb_word_index.json 。首次访问时会自动下载,也可以从https://s3.amazonaws.com/text-datasets/imdb_word_index.json手动下载。

您会看到 imdb_word_index.json 均为小写字符。因此,生成的word2index字典也将具有全部小写的键。

解决方案是在检查该数据集时也使用小写字母。

for word in word_tokenize(good.lower()):
     X.append(word2index[word])

print(X)  # [3, 84, 17]