I have created a model in Keras using an LSTM to predict the next word given a sequence of words. Below is my code:
# Small LSTM Network to Generate Text for Alice in Wonderland
import re
import numpy
from keras.models import Sequential
from keras.layers import Dense, Dropout, LSTM
from keras.callbacks import ModelCheckpoint
from keras.utils import np_utils
# load ascii text and convert to lowercase
filename = "wonderland.txt"
raw_text = open(filename).read()
raw_text = raw_text.lower()
print(raw_text)
# create mapping of unique words to integers
raw_text = re.sub(r'[^\w\s]','',raw_text)
raw_text = re.sub(r"[^a-z ']+", " ", raw_text)
words_unsorted = raw_text.split()
words = sorted(set(raw_text.split()))
word_to_int = dict((w, i) for i, w in enumerate(words))
int_to_word = dict((i, w) for i, w in enumerate(words))
#print word_to_int
n_words = len(words_unsorted)
n_vocab = len(words)
print "Total Words: ", n_words
print "Total Vocab: ", n_vocab
# prepare the dataset of input to output pairs encoded as integers
seq_length = 7
dataX = []
dataY = []
for i in range(0, n_words - seq_length, 1):
seq_in = words_unsorted[i:i + seq_length]
seq_out = words_unsorted[i + seq_length]
#print seq_in
dataX.append([word_to_int[word] for word in seq_in])
dataY.append(word_to_int[seq_out])
n_patterns = len(dataX)
print "Total Patterns: ", n_patterns
# reshape X to be [samples, time steps, features]
X = numpy.reshape(dataX, (n_patterns, seq_length, 1))
print(X[0])
# normalize
X = X / float(n_vocab)
# one hot encode the output variable
y = np_utils.to_categorical(dataY)
# define the LSTM model
model = Sequential()
model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2]), return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(256))
model.add(Dropout(0.2))
model.add(Dense(y.shape[1], activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam')
model.summary()
# define the checkpoint
filepath="weights-improvement-{epoch:02d}-{loss:.4f}.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True, mode='min')
callbacks_list = [checkpoint]
# fit the model
model.fit(X, y, epochs=50, batch_size=128, callbacks=callbacks_list)
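For reference, this is roughly how I generate a prediction from a seed sequence (a minimal sketch; the seed index 100 is arbitrary):

# pick a seed of seq_length encoded words from the training data
pattern = dataX[100]
print("Seed:", " ".join(int_to_word[v] for v in pattern))
# reshape and normalize exactly like the training inputs
x = numpy.reshape(pattern, (1, seq_length, 1))
x = x / float(n_vocab)
prediction = model.predict(x, verbose=0)
# the highest-probability class index maps back to the predicted word
print("Next word:", int_to_word[numpy.argmax(prediction)])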
The problem is that when I predict on a test sentence, I always get "and" as the next-word prediction! Should I remove all stopwords, or do something else? Also, I am training it for 20 epochs.
Answer 0 (score: 0)
I'm pretty sure, given the age of this post, that you have already solved the problem. But just in case, here are my two cents.
You end up predicting the most frequent word. So if you remove the stopwords, you will simply predict the next most frequent word instead. I know of two ways to address this.
First, you can use a loss that emphasizes the less frequent classes (or, in your case, words). There is a research paper that introduces focal loss for this, along with a handy Keras implementation on GitHub.
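For illustration, a categorical focal loss can be sketched directly with the Keras backend (gamma and alpha here are the paper's illustrative defaults, not values tuned for this task):

import keras.backend as K

def focal_loss(gamma=2.0, alpha=0.25):
    # down-weights well-classified examples, which tend to be the frequent words
    def loss(y_true, y_pred):
        y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
        cross_entropy = -y_true * K.log(y_pred)
        weight = alpha * K.pow(1.0 - y_pred, gamma)
        return K.sum(weight * cross_entropy, axis=-1)
    return loss

model.compile(loss=focal_loss(), optimizer='adam')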
The other approach is to use class_weight in the fit function:
model.fit(X, y, epochs=50, batch_size=128, callbacks=callbacks_list, class_weight=class_weight)
where you set the weights of the less frequent words higher, for example inversely proportional to their frequency.
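As a minimal sketch (assuming a plain inverse-frequency scheme; the exact weighting is up to you), the weights can be computed from the training targets:

from collections import Counter

# count how often each word index appears as a prediction target
counts = Counter(dataY)
total = float(len(dataY))
# inverse-frequency weights; fall back to 1.0 for indices that never occur as targets
class_weight = {i: total / counts[i] if counts[i] else 1.0 for i in range(n_vocab)}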