RNN:训练模型后从文本输入中获取预测

时间:2018-08-25 17:49:51

标签: python tensorflow machine-learning keras text-classification

我是RNN的新手,我一直在研究小型二进制标签分类器。我已经能够获得满意结果的稳定模型。

但是,我很难使用该模型对新输入进行分类,我想知道你们中的任何人是否可以帮助我。请参阅下面的代码以供参考。

非常感谢您。

from tensorflow.keras import preprocessing
from sklearn.utils import shuffle
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.models import Model
from tensorflow.keras import models
from tensorflow.keras.layers import LSTM, Activation, Dense, Dropout, Input, 
Embedding
from tensorflow.keras.optimizers import RMSprop, Adam
from tensorflow.keras.preprocessing import sequence, text
from tensorflow.keras.callbacks import EarlyStopping
from matplotlib import pyplot

class tensor_rnn():
def __init__(self, hidden_layers=3):
    self.data_path = 'C:\\\\Users\\cmazz\\PycharmProjects\\InvestmentAnalysis_2.0\\Sentiment\\Finance_Articles\\'
    # self.corp_paths = corpora_paths
    self.h_layers = hidden_layers
    self.num_words = []
    good = pd.read_csv(self.data_path + 'GoodO.csv')
    good['Polarity'] = 'pos'
    for line in good['Head'].tolist():
        counter = len(line.split())
        self.num_words.append(counter)
    bad = pd.read_csv(self.data_path + 'BadO.csv')
    bad['Polarity'] = 'neg'
    for line in bad['Head'].tolist():
        counter = len(line.split())
        self.num_words.append(counter)
    self.features = pd.concat([good, bad]).reset_index(drop=True)
    self.features = shuffle(self.features)

    self.max_len = len(max(self.features['Head'].tolist()))
    # self.train, self.test = train_test_split(features, test_size=0.33, random_state=42)
    X = self.features['Head']
    Y = self.features['Polarity']
    le = LabelEncoder()
    Y = le.fit_transform(Y)
    Y = Y.reshape(-1, 1)
    self.X_train, self.X_test, self.Y_train, self.Y_test = train_test_split(X, Y, test_size=0.30)
    self.tok = preprocessing.text.Tokenizer(num_words=len(self.num_words))
    self.tok.fit_on_texts(self.X_train)
    sequences = self.tok.texts_to_sequences(self.X_train)
    self.sequences_matrix = preprocessing.sequence.pad_sequences(sequences, maxlen=self.max_len)

def RNN(self):
    inputs = Input(name='inputs', shape=[self.max_len])
    layer = Embedding(len(self.num_words), 30, input_length=self.max_len)(inputs)
    # layer = LSTM(64, return_sequences=True)(layer)
    layer = LSTM(32)(layer)
    layer = Dense(256, name='FC1')(layer)
    layer = Activation('relu')(layer)
    layer = Dropout(0.5)(layer)
    layer = Dense(1, name='out_layer')(layer)
    layer = Activation('sigmoid')(layer)
    model = Model(inputs=inputs, outputs=layer)
    return model

def model_train(self):
    self.model = self.RNN()
    self.model.summary()
    self.model.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])   # RMSprop()

def model_test(self):
    self.history = self.model.fit(self.sequences_matrix, self.Y_train, batch_size=100, epochs=3,
              validation_split=0.30, callbacks=[EarlyStopping(monitor='val_loss', min_delta=0.0001)])
    test_sequences = self.tok.texts_to_sequences(self.X_test)
    test_sequences_matrix = sequence.pad_sequences(test_sequences, maxlen=self.max_len)
    accr = self.model.evaluate(test_sequences_matrix, self.Y_test)
    print('Test set\n  Loss: {:0.3f}\n  Accuracy: {:0.3f}'.format(accr[0], accr[1]))


if __name__ == "__main__":
    a = tensor_rnn()
    a.model_train()
    a.model_test()
    a.model.save('C:\\\\Users\\cmazz\\PycharmProjects\\'
                              'InvestmentAnalysis_2.0\\RNN_Model.h5', 
    include_optimizer=True)
     b = models.load_model('C:\\\\Users\\cmazz\\PycharmProjects\\'
                              'InvestmentAnalysis_2.0\\RNN_Model.h5')
    stringy = ['Fund managers back away from Amazon as they cut FANG exposure']
    prediction = b.predict(np.array(stringy))
    print(prediction)

运行代码时,出现以下错误:

  

ValueError:检查输入时出错:预期输入具有形状   (39,)但形状为(1,)

的数组

1 个答案:

答案 0 :(得分:3)

基于ValueError和prediction = b.predict(np.array(stringy)),我认为您需要标记输入字符串。