Word prediction from a corpus

Date: 2018-08-23 09:05:52

Tags: python-3.x python-2.7 keras lstm recurrent-neural-network

I am trying to build a model that predicts the word a user is typing (drug names, in this case). I have a corpus of all the words that will ever be typed. When I try to predict these words with an RNN and LSTM, I get words that resemble the ones in the corpus but are not the actual words.

For example, the user types: flu

Expected output: fluvoxamine

What I get instead: {'re', 'cobact', 'nate', 'diclox', 'te-cv'}

Appended to the typed prefix, these give names like "flure", "flucobact", ... that are similar to many names in the drug corpus, but no such names actually exist in it.

Where am I going wrong, and what should I change to get the desired output? The code I am using for this is:

# Setup
import numpy as np
import tensorflow as tf
from keras.models import Sequential, load_model
from keras.layers import Dense, Activation, LSTM
from keras.optimizers import RMSprop
import matplotlib

matplotlib.use('agg')
import matplotlib.pyplot as plt
import pickle
import sys
import heapq
import seaborn as sns
from pylab import rcParams

np.random.seed(42)
tf.set_random_seed(42)

sns.set(style='whitegrid', palette='muted', font_scale=1.5)
rcParams['figure.figsize'] = 12, 5

# Loading the data
# path = 'nietzsche.txt'
# path = 'nietzschemod.txt'
path = 'BrandedProductNames.txt'
text = open(path).read().lower()
print ('Corpus length: ', len(text))

# Preprocessing
# Finding all the unique characters in the corpus
chars = sorted(list(set(text)))
char_indices = dict((c, i) for i, c in enumerate(chars))
indices_char = dict((i, c) for i, c in enumerate(chars))

print ("unique chars: ", len(chars))

# Cutting the corpus into windows of SEQUENCE_LENGTH chars, spacing the sequences by `step` characters
# We will additionally store the next character (the one we need to predict) for every sequence

SEQUENCE_LENGTH = 3
step = 3
sentences = []
next_chars = []
for i in range(0, len(text) - SEQUENCE_LENGTH, step):
    sentences.append(text[i:i + SEQUENCE_LENGTH])
    next_chars.append(text[i + SEQUENCE_LENGTH])
print ('num training examples: ', len(sentences))
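
# Worked example (hypothetical corpus line): for "abacavir\n" with
# SEQUENCE_LENGTH = 3 and step = 3 the loop yields
#   sentences = ['aba', 'cav']    next_chars = ['c', 'i']
# so with step = 3, two out of every three character offsets never become
# a training window.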

# Generating features and labels.
# Using previously generated sequences and characters that need to be predicted to create one-hot encoded vectors

X = np.zeros((len(sentences), SEQUENCE_LENGTH, len(chars)), dtype=bool)
y = np.zeros((len(sentences), len(chars)), dtype=bool)
for i, sentence in enumerate(sentences):
    for t, char in enumerate(sentence):
        X[i, t, char_indices[char]] = 1
    y[i, char_indices[next_chars[i]]] = 1
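
# Resulting shapes: X is (num examples, SEQUENCE_LENGTH, len(chars)) and
# y is (num examples, len(chars)); each row of y one-hot encodes the
# character that follows its window.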

# """
# Building the model

model = Sequential()
model.add(LSTM(128, input_shape=(SEQUENCE_LENGTH, len(chars))))
model.add(Dense(len(chars)))
model.add(Activation('softmax'))

# Training
optimizer = RMSprop(lr=0.01)
model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

# history = model.fit(X, y, validation_split=0.05, batch_size=128, epochs=20, shuffle=True).history
history = model.fit(X, y, validation_split=0.05, batch_size=128, epochs=1000, shuffle=True).history

# Saving
# model.save('keras_model'+str(SEQUENCE_LENGTH)+'.h5')
model.save('Med_keras_model' + str(SEQUENCE_LENGTH) + '.h5')
# pickle.dump(history, open('history'+str(SEQUENCE_LENGTH)+'.p', 'wb'))
pickle.dump(history, open('Med_history' + str(SEQUENCE_LENGTH) + '.p', 'wb'))
# """

# Loading back the saved weights and history

# model = load_model('keras_model'+str(SEQUENCE_LENGTH)+'.h5')
# history = pickle.load(open('history'+str(SEQUENCE_LENGTH)+'.p', 'rb'))
model = load_model('Med_keras_model' + str(SEQUENCE_LENGTH) + '.h5')
history = pickle.load(open('Med_history' + str(SEQUENCE_LENGTH) + '.p', 'rb'))

# Evaluation
plt.plot(history['acc'])
plt.plot(history['val_acc'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')

plt.savefig("01.Accuracy.png")
plt.show()

plt.figure()  # new figure so the loss curves do not overlay the accuracy plot
plt.plot(history['loss'])
plt.plot(history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')

plt.savefig("02.Loss.png")
plt.show()


# Testing
def prepare_input(text):
    x = np.zeros((1, SEQUENCE_LENGTH, len(chars)))
    for t, char in enumerate(text):
        x[0, t, char_indices[char]] = 1
    return x


# The input sequence must be SEQUENCE_LENGTH chars long, so the tensor has shape (1, SEQUENCE_LENGTH, len(chars))


# The sample function
# This function asks the model for the most probable next characters (the heap simplifies picking the top n)
def sample(preds, top_n=3):
    preds = np.asarray(preds).astype('float64')
    preds = np.log(preds)
    exp_preds = np.exp(preds)
    preds = exp_preds / np.sum(exp_preds)
    return heapq.nlargest(top_n, range(len(preds)), preds.take)
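
# Note: np.log immediately followed by np.exp cancels out, so sample() just
# renormalizes preds and returns the indices of the top_n probabilities,
# i.e. the usual char-rnn sampling helper with the temperature fixed at 1.
# A temperature variant (a sketch, not in the original) would instead do:
#   preds = np.exp(np.log(preds) / temperature)
#   preds = preds / np.sum(preds)
# predict_completion below calls sample(preds, top_n=1), i.e. greedy argmax.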


# Prediction function
def predict_completion(text):
    # print "From predict_completion: Text= "+text
    original_text = text
    completion = ''
    while True:
        x = prepare_input(text)
        preds = model.predict(x, verbose=0)[0]
        next_index = sample(preds, top_n=1)[0]
        next_char = indices_char[next_index]

        text = text[1:] + next_char
        # print ("From predict_completion: Text after appending= " + text)
        completion += next_char
        # print ("From predict_completion: completion= " + completion)
        # print (str(len(original_text + completion) + 2),str(len(original_text)), str(next_char == ' '), next_char)
        print (next_char)
        if len(original_text + completion) + 2 > len(original_text) and (
                next_char == ' ' or next_char == '\n' or next_char == '\r' or next_char == '\t'):
            return completion
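
# Note on the stop condition above: len(original_text + completion) + 2 >
# len(original_text) reduces to len(completion) + 2 > 0, which is always
# true, so the loop really just runs until the first whitespace character;
# if the model never emits whitespace, the while True loop never returns.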


# This method wraps everything and allows us to predict multiple completions
def predict_completions(text, n=3):
    x = prepare_input(text)
    # print "Prepared input for "+text
    preds = model.predict(x, verbose=0)[0]
    # print "Prediction Done for "+text
    next_indices = sample(preds, n)
    # print "Created next indices "+text
    return [indices_char[idx] + predict_completion(text[1:] + indices_char[idx]) for idx in next_indices]
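
# Each returned string is one candidate first character plus its greedy
# rollout from predict_completion. For the seed 'flu', a completion 're'
# therefore corresponds to the full word 'flure'; nothing here checks the
# assembled word against the corpus vocabulary.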


inputtxt = [
    "abacavir",
    "acemetacin",
    "dasatinib",
    "erythropoietin",
    "fluvoxamine"
]

for txt in inputtxt:
    seq = txt[:SEQUENCE_LENGTH].lower()
    print (seq)
    print (predict_completions(seq, 5))
    print ()
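
For contrast, the behaviour I actually want is what a plain prefix lookup over the same corpus would give, since by construction it can only return real names. A minimal sketch (assuming BrandedProductNames.txt has one name per line; lookup_completions is a hypothetical helper, not part of my model):

# Hypothetical baseline: plain prefix lookup over the same corpus file.
# Unlike the LSTM rollout, this can only ever return names that exist.
with open('BrandedProductNames.txt') as f:
    names = [line.strip().lower() for line in f if line.strip()]

def lookup_completions(prefix, n=3):
    # Return the remainders after the typed prefix, mirroring predict_completions
    matches = [name for name in names if name.startswith(prefix)]
    return [name[len(prefix):] for name in matches[:n]]

print(lookup_completions('flu', 5))  # e.g. ['voxamine', ...], depending on the corpus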

0 Answers:

There are no answers.