Online LSTM classification model produces a large number of wrong predictions

Date: 2018-09-26 05:47:48

Tags: python tensorflow machine-learning lstm text-classification

I am trying to implement an online classification model on the 20 Newsgroups dataset to classify posts into their relevant groups.

Preprocessing: I iterate over all the posts and build a dictionary of their words. Then I index the words, starting from 1. Next, for each post, I look up each of its words in the vocabulary and append the corresponding index to an array. Finally, I pad every array with trailing 0s so that they all have the same length (6577).
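A minimal sketch of this preprocessing, assuming the posts have already been cleaned; the `posts` list below is an illustrative placeholder, not data from the original code:

from collections import Counter
import numpy as np

posts = ["first post text", "second post"]  # placeholder data
word_counts = Counter(w for p in posts for w in p.split())
# index words from 1 so that 0 can serve as the padding value
vocab_to_int = {w: i for i, (w, _) in enumerate(word_counts.most_common(), 1)}

seq_length = 6577  # length of the longest post in the corpus
features = np.zeros((len(posts), seq_length), dtype=int)
for i, p in enumerate(posts):
    ints = [vocab_to_int[w] for w in p.split()][:seq_length]
    features[i, :len(ints)] = ints  # remaining positions stay 0 (padding)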

Then I create an embedding layer (embedding size = 300). Each input goes through this embedding layer before being fed to the LSTM layer (LSTM input shape = (1, 6577, 300)).
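In TF 1.x this is just a trainable lookup table; a sketch, where the n_words value is a placeholder for the vocabulary size produced by the preprocessing:

import tensorflow as tf

n_words = 10000   # placeholder: vocabulary size + 1 from preprocessing
embed_size = 300
inputs_ = tf.placeholder(tf.int32, [None, None], name="inputs")
embedding = tf.Variable(tf.random_uniform((n_words, embed_size), -1, 1))
embed = tf.nn.embedding_lookup(embedding, inputs_)  # (batch, 6577, 300)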

In my model I have one LSTM layer (size = 200) and one hidden layer (size = 25). I use TensorFlow's dynamic_rnn for this and set its sequence_length parameter to each post's actual length (its length without the padding 0s) to avoid analyzing the padding. Then, from the LSTM layer's outputs, I feed only the relevant output (the one at the last real time step) into the hidden layer.
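Selecting that last relevant output is the subtle part. A self-contained TF 1.x sketch of the same indexing trick used in the full code below (the placeholder shapes and batch size 1 are assumptions for illustration):

import tensorflow as tf

lstm_size = 200
lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size)
sql_in = tf.placeholder(tf.int32, [None])               # true (unpadded) lengths
embed = tf.placeholder(tf.float32, [None, None, 300])   # embedded posts
initial_state = lstm.zero_state(1, tf.float32)

outputs, final_state = tf.nn.dynamic_rnn(
    lstm, embed, initial_state=initial_state, sequence_length=sql_in)

# flatten (batch, time, units) to (batch * time, units), then gather the row
# at batch_i * max_time + (length_i - 1): the last real step of each post
out_batch = tf.shape(outputs)[0]
max_time = tf.shape(outputs)[1]
units = int(outputs.get_shape()[2])
index = tf.range(0, out_batch) * max_time + (sql_in - 1)
relevant = tf.gather(tf.reshape(outputs, [-1, units]), index)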

From there on, it is like a normal LSTM implementation. I have tried everything I can think of to improve the model's accuracy, but the number of wrong predictions is very high:

  Number of data points: 18,846
  Errors: 17,876
  Error rate: 0.9485301920832007

Note: I am training the embedding layer and the hidden layer as well during backpropagation.

Question: I would like to know what I am doing wrong here, or whether there are any ideas to improve the model. Thank you in advance.

My full code is shown below:

from collections import Counter
import tensorflow as tf
from sklearn.datasets import fetch_20newsgroups
import matplotlib as mplt
mplt.use('agg') # Must be before importing matplotlib.pyplot or pylab!
import matplotlib.pyplot as plt
from string import punctuation
from sklearn.preprocessing import LabelBinarizer
import numpy as np
from nltk.corpus import stopwords
import nltk
nltk.download('stopwords')



def pre_process():
    newsgroups_data = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))

    words = []
    temp_post_text = []
    print(len(newsgroups_data.data))

    for post in newsgroups_data.data:

        all_text = ''.join([text for text in post if text not in punctuation])
        all_text = all_text.split('\n')
        all_text = ''.join(all_text)
        temp_text = all_text.split(" ")

        for word in temp_text:
            if word.isalpha():
                temp_text[temp_text.index(word)] = word.lower()

        # temp_text = [word for word in temp_text if word not in stopwords.words('english')]
        temp_text = list(filter(None, temp_text))
        temp_text = ' '.join([i for i in temp_text if not i.isdigit()])
        words += temp_text.split(" ")
        temp_post_text.append(temp_text)

    # temp_post_text = list(filter(None, temp_post_text))

    dictionary = Counter(words)
    # deleting spaces
    # del dictionary[""]
    sorted_split_words = sorted(dictionary, key=dictionary.get, reverse=True)
    vocab_to_int = {c: i for i, c in enumerate(sorted_split_words,1)}

    message_ints = []
    for message in temp_post_text:
        temp_message = message.split(" ")
        message_ints.append([vocab_to_int[i] for i in temp_message])


    # maximum message length = 6577

    # message_lens = Counter([len(x) for x in message_ints])

    seq_length = 6577
    num_messages = len(temp_post_text)
    features = np.zeros([num_messages, seq_length], dtype=int)
    for i, row in enumerate(message_ints):
        # print(features[i, -len(row):])
        # features[i, -len(row):] = np.array(row)[:seq_length]
        features[i, :len(row)] = np.array(row)[:seq_length]
        # print(features[i])

    lb = LabelBinarizer()
    lbl = newsgroups_data.target
    labels = np.reshape(lbl, [-1])
    labels = lb.fit_transform(labels)

    sequence_lengths = [len(msg) for msg in message_ints]
    return features, labels, len(sorted_split_words)+1, sequence_lengths


def get_batches(x, y, sql, batch_size=1):
    for ii in range(0, len(y), batch_size):
        yield x[ii:ii + batch_size], y[ii:ii + batch_size], sql[ii:ii+batch_size]


def plot(noOfWrongPred, dataPoints):
    font_size = 14
    fig = plt.figure(dpi=100,figsize=(10, 6))
    mplt.rcParams.update({'font.size': font_size})
    plt.title("Distribution of wrong predictions", fontsize=font_size)
    plt.ylabel('Error rate', fontsize=font_size)
    plt.xlabel('Number of data points', fontsize=font_size)

    plt.plot(dataPoints, noOfWrongPred, label='Prediction', color='blue', linewidth=1.8)
    # plt.legend(loc='upper right', fontsize=14)

    plt.savefig('distribution of wrong predictions.png')
    # plt.show()



def train_test():
    features, labels, n_words, sequence_length = pre_process()

    print(features.shape)
    print(labels.shape)

    # Defining Hyperparameters

    lstm_layers = 1
    batch_size = 1
    lstm_size = 200
    learning_rate = 0.01

    # --------------placeholders-------------------------------------

    # Create the graph object
    graph = tf.Graph()
    # Add nodes to the graph
    with graph.as_default():

        tf.set_random_seed(1)

        inputs_ = tf.placeholder(tf.int32, [None, None], name="inputs")
        # labels_ = tf.placeholder(dtype= tf.int32)
        labels_ = tf.placeholder(tf.float32, [None, None], name="labels")
        sql_in = tf.placeholder(tf.int32, [None], name= 'sql_in')

        # output_keep_prob is the dropout added to the RNN's outputs, the dropout will have no effect on the calculation of the subsequent states.
        keep_prob = tf.placeholder(tf.float32, name="keep_prob")

        # Size of the embedding vectors (number of units in the embedding layer)
        embed_size = 300

        # generating random values from a uniform distribution (minval included and maxval excluded)
        embedding = tf.Variable(tf.random_uniform((n_words, embed_size), -1, 1),trainable=True)
        embed = tf.nn.embedding_lookup(embedding, inputs_)

        print(embedding.shape)
        print(embed.shape)
        print(embed[0])

        # Your basic LSTM cell
        lstm =  tf.contrib.rnn.BasicLSTMCell(lstm_size)

        # Getting an initial state of all zeros
        initial_state = lstm.zero_state(batch_size, tf.float32)

        outputs, final_state = tf.nn.dynamic_rnn(lstm, embed, initial_state=initial_state, sequence_length=sql_in)

        out_batch_size = tf.shape(outputs)[0]
        out_max_length = tf.shape(outputs)[1]
        out_size = int(outputs.get_shape()[2])
        # pick, for each example, the output row at its last real time step
        index = tf.range(0, out_batch_size) * out_max_length + (sql_in - 1)
        flat = tf.reshape(outputs, [-1, out_size])
        relevant = tf.gather(flat, index)

        # hidden layer
        hidden = tf.layers.dense(relevant, units=25, activation=tf.nn.relu,trainable=True)

        print(hidden.shape)

        logit = tf.contrib.layers.fully_connected(hidden, num_outputs=20, activation_fn=None)

        cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logit, labels=labels_))


        optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)


        saver = tf.train.Saver()

    # ----------------------------online training-----------------------------------------

    with tf.Session(graph=graph) as sess:
        tf.set_random_seed(1)
        sess.run(tf.global_variables_initializer())
        iteration = 1
        state = sess.run(initial_state)
        wrongPred = 0
        noOfWrongPreds = []
        dataPoints = []

        for ii, (x, y, sql) in enumerate(get_batches(features, labels, sequence_length, batch_size), 1):

            feed = {inputs_: x,
                    labels_: y,
                    sql_in : sql,
                    keep_prob: 0.5,
                    initial_state: state}

            predictions = tf.nn.softmax(logit).eval(feed_dict=feed)

            print("----------------------------------------------------------")
            print("sez: ",sql)
            print("Iteration: {}".format(iteration))

            isequal = np.equal(np.argmax(predictions[0], 0), np.argmax(y[0], 0))

            print(np.argmax(predictions[0], 0))
            print(np.argmax(y[0], 0))

            if not (isequal):
                wrongPred += 1

            print("nummber of wrong preds: ",wrongPred)

            if iteration%50 == 0:
                noOfWrongPreds.append(wrongPred/iteration)
                dataPoints.append(iteration)

            loss, states, _ = sess.run([cost, outputs, optimizer], feed_dict=feed)

            print("Train loss: {:.3f}".format(loss))
            iteration += 1

        saver.save(sess, "checkpoints/sentiment.ckpt")
        errorRate = wrongPred / len(labels)
        print("ERRORS: ", wrongPred)
        print("ERROR RATE: ", errorRate)
        plot(noOfWrongPreds, dataPoints)


if __name__ == '__main__':
    train_test()

EDIT

[image omitted]

1 Answer:

Answer 0 (score: 0):

A few things to consider:

  1. Plot the loss against the number of iterations. To see whether your network is learning, this curve should go down. You can use tensorboard to generate these plots; plot accuracy against the number of iterations as well.
  2. Increase the batch size from 1 to mini-batches of 64 or 128, depending on your system configuration (RAM).
  3. Use a bidirectional LSTM, since you already have the complete post before training, to improve accuracy. A sketch of suggestions 1 and 3 follows this list.
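A hedged TF 1.x sketch of suggestions 1 and 3, reusing the names cost, graph, embed, sql_in, and lstm_size from the question's code; the "logdir" path is a placeholder, not from the original:

# Suggestion 1: log the training loss so tensorboard can plot it.
tf.summary.scalar("train_loss", cost)
merged = tf.summary.merge_all()
writer = tf.summary.FileWriter("logdir", graph)  # placeholder path
# inside the training loop:
#   loss, summary, _ = sess.run([cost, merged, optimizer], feed_dict=feed)
#   writer.add_summary(summary, iteration)

# Suggestion 3: run one LSTM forward and one backward over each post,
# then concatenate their outputs along the feature axis.
fw_cell = tf.contrib.rnn.BasicLSTMCell(lstm_size)
bw_cell = tf.contrib.rnn.BasicLSTMCell(lstm_size)
(out_fw, out_bw), _ = tf.nn.bidirectional_dynamic_rnn(
    fw_cell, bw_cell, embed, sequence_length=sql_in, dtype=tf.float32)
outputs = tf.concat([out_fw, out_bw], axis=-1)  # (batch, time, 2 * lstm_size)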

EDIT

Your model is not learning the weights properly. I ran your code, and it only predicts class 0. Compare the prediction and the label in the log below: the prediction is always 0.

Iteration: 1  prediction: 0  label: 10  wrong preds: 1  Train loss: 3.116
Iteration: 2  prediction: 0  label: 3  wrong preds: 2  Train loss: 3.163
Iteration: 3  prediction: 0  label: 17  wrong preds: 3  Train loss: 3.212
Iteration: 4  prediction: 0  label: 3  wrong preds: 4  Train loss: 2.992
Iteration: 5  prediction: 0  label: 4  wrong preds: 5  Train loss: 2.892
Iteration: 6  prediction: 0  label: 12  wrong preds: 6  Train loss: 3.077
Iteration: 7  prediction: 0  label: 4  wrong preds: 7  Train loss: 2.546
Iteration: 8  prediction: 0  label: 10  wrong preds: 8  Train loss: 3.459
Iteration: 9  prediction: 0  label: 10  wrong preds: 9  Train loss: 2.341
Iteration: 10  prediction: 0  label: 19  wrong preds: 10  Train loss: 3.303
Iteration: 11  prediction: 0  label: 19  wrong preds: 11  Train loss: 3.193
Iteration: 12  prediction: 0  label: 11  wrong preds: 12  Train loss: 3.323
Iteration: 13  prediction: 0  label: 19  wrong preds: 13  Train loss: 2.773
Iteration: 14  prediction: 0  label: 13  wrong preds: 14  Train loss: 3.129
Iteration: 15  prediction: 0  label: 0  wrong preds: 14  Train loss: 3.992
Iteration: 16  prediction: 0  label: 17  wrong preds: 15  Train loss: 3.010
Iteration: 17  prediction: 0  label: 12  wrong preds: 16  Train loss: 2.534
Iteration: 18  prediction: 0  label: 12  wrong preds: 17  Train loss: 2.804
Iteration: 19  prediction: 0  label: 11  wrong preds: 18  Train loss: 4.369
Iteration: 20  prediction: 0  label: 8  wrong preds: 19  Train loss: 4.028
Iteration: 21  prediction: 0  label: 7  wrong preds: 20  Train loss: 3.844
Iteration: 22  prediction: 0  label: 5  wrong preds: 21  Train loss: 3.579
Iteration: 23  prediction: 0  label: 1  wrong preds: 22  Train loss: 3.418
Iteration: 24  prediction: 0  label: 8  wrong preds: 23  Train loss: 4.337
Iteration: 25  prediction: 0  label: 10  wrong preds: 24  Train loss: 2.328
Iteration: 26  prediction: 0  label: 14  wrong preds: 25  Train loss: 4.216
Iteration: 27  prediction: 0  label: 16  wrong preds: 26  Train loss: 3.155
Iteration: 28  prediction: 0  label: 1  wrong preds: 27  Train loss: 3.307
Iteration: 29  prediction: 0  label: 6  wrong preds: 28  Train loss: 3.744
Iteration: 30  prediction: 0  label: 0  wrong preds: 28  Train loss: 4.180
Iteration: 31  prediction: 0  label: 7  wrong preds: 29  Train loss: 3.400
Iteration: 32  prediction: 0  label: 16  wrong preds: 30  Train loss: 2.706
Iteration: 33  prediction: 0  label: 5  wrong preds: 31  Train loss: 2.994
Iteration: 34  prediction: 0  label: 9  wrong preds: 32  Train loss: 3.610
Iteration: 35  prediction: 0  label: 13  wrong preds: 33  Train loss: 2.689
Iteration: 36  prediction: 0  label: 4  wrong preds: 34  Train loss: 2.755
Iteration: 37  prediction: 0  label: 4  wrong preds: 35  Train loss: 2.778
Iteration: 38  prediction: 0  label: 18  wrong preds: 36  Train loss: 3.361
Iteration: 39  prediction: 0  label: 8  wrong preds: 37  Train loss: 3.640
Iteration: 40  prediction: 0  label: 8  wrong preds: 38  Train loss: 3.276
Iteration: 41  prediction: 0  label: 19  wrong preds: 39  Train loss: 2.796
Iteration: 42  prediction: 0  label: 1  wrong preds: 40  Train loss: 3.189
Iteration: 43  prediction: 0  label: 12  wrong preds: 41  Train loss: 2.901
Iteration: 44  prediction: 0  label: 7  wrong preds: 42  Train loss: 2.913
Iteration: 45  prediction: 0  label: 10  wrong preds: 43  Train loss: 2.875
Iteration: 46  prediction: 0  label: 5  wrong preds: 44  Train loss: 3.005
Iteration: 47  prediction: 0  label: 2  wrong preds: 45  Train loss: 3.246
Iteration: 48  prediction: 0  label: 6  wrong preds: 46  Train loss: 3.071
Iteration: 49  prediction: 0  label: 11  wrong preds: 47  Train loss: 2.971
Iteration: 50  prediction: 0  label: 2  wrong preds: 48  Train loss: 3.192
Iteration: 51  prediction: 0  label: 12  wrong preds: 49  Train loss: 2.894
Iteration: 52  prediction: 0  label: 7  wrong preds: 50  Train loss: 2.980