Text classification with an RNN in Keras

Time: 2018-06-12 10:47:39

Tags: python tensorflow keras rnn

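I am trying to understand RNN for text classification using keras/tensorflow. At the moment it classifies positive/negative sentiment. How can I change it to classify other classes, for example the two classes Question and Not-Question?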

    # LSTM for sequence classification in the IMDB dataset
    import numpy
    from keras.datasets import imdb
    from keras.models import Sequential
    from keras.layers import Dense
    from keras.layers import LSTM
    from keras.layers import Embedding
    from keras.preprocessing import sequence
    # fix random seed for reproducibility
    numpy.random.seed(7)

    # load the dataset but only keep the top n words, zero the rest
    top_words = 5000

    (X_train, y_train), (X_test, y_test) = imdb.load_data(num_words=top_words)

    # truncate and pad input sequences
    max_review_length = 500
    X_train = sequence.pad_sequences(X_train, maxlen=max_review_length)
    X_test = sequence.pad_sequences(X_test, maxlen=max_review_length)

    # create the model
    embedding_vector_length = 32
    model = Sequential()
    model.add(Embedding(top_words, embedding_vector_length, input_length=max_review_length))
    model.add(LSTM(100))
    model.add(Dense(1, activation='sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    # summary() prints the table itself and returns None, so no print() is needed
    model.summary()
    model.fit(X_train, y_train, epochs=3, batch_size=64)
    # Final evaluation of the model
    scores = model.evaluate(X_test, y_test, verbose=0)
    print("Accuracy: %.2f%%" % (scores[1]*100))

1 answer:

Answer 0 (score: 0)

When you have Dense(1, activation='sigmoid') as the last layer of your model, you are doing binary (two-class) classification. So you can use the same model to learn the Question / Not-Question setting: just label each example 1 or 0 instead of positive or negative.
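As a minimal sketch of the label side, assuming your data arrives as the strings "Question" and "Not-Question" (hypothetical names, not from the original post), the only change is to map them to the 0/1 targets a sigmoid output expects, and to threshold the predicted probability when decoding. The mapping dictionaries and helper functions below are illustrative, not part of Keras.

```python
# Hypothetical label names for the Question / Not-Question task (assumption).
LABEL_TO_ID = {"Not-Question": 0, "Question": 1}
ID_TO_LABEL = {v: k for k, v in LABEL_TO_ID.items()}


def encode_labels(labels):
    """Map string labels to the 0/1 targets used with Dense(1, activation='sigmoid').

    Raises KeyError on an unknown label instead of silently mapping it to 0.
    """
    return [LABEL_TO_ID[label] for label in labels]


def decode_predictions(probs, threshold=0.5):
    """Turn sigmoid outputs, read as P(Question), back into string labels."""
    return [ID_TO_LABEL[int(p >= threshold)] for p in probs]


y_train = encode_labels(["Question", "Not-Question", "Question"])
print(y_train)                                  # [1, 0, 1]
print(decode_predictions([0.91, 0.12, 0.5]))    # ['Question', 'Not-Question', 'Question']
```

The encoded list (or a NumPy array built from it) can be passed as y_train to model.fit with the model unchanged, since binary_crossentropy only cares that the targets are 0 or 1.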

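If you want more than two classes, the usual approach is a final Dense(num_classes, activation='softmax') layer with loss='categorical_crossentropy', which produces a probability distribution over the possible classes. The targets then have to be one-hot encoded (for example with keras.utils.to_categorical); alternatively, integer class labels can be used directly with loss='sparse_categorical_crossentropy'. In fact, binary classification is just a special case of this: a Dense(2, activation='softmax') layer trained with categorical cross-entropy, where the targets are "[1,0]" and "[0,1]", i.e. the one-hot encodings of the target class.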
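To make the "special case" claim concrete, here is a small pure-Python check (no Keras needed). It assumes hand-written helpers named sigmoid, softmax, one_hot, binary_crossentropy and categorical_crossentropy; these only mirror the math and are not the Keras implementations. It shows that a sigmoid on a logit z equals the second output of a softmax over the logits [0, z], and that the two losses agree for both one-hot targets.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def softmax(logits):
    # Subtract the max logit for numerical stability; the result is unchanged.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def one_hot(class_id, num_classes):
    return [1.0 if i == class_id else 0.0 for i in range(num_classes)]


def binary_crossentropy(y, p):
    # y is 0 or 1, p is the sigmoid output P(class 1).
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def categorical_crossentropy(target, probs):
    # target is a one-hot vector, probs a softmax distribution.
    return -sum(t * math.log(p) for t, p in zip(target, probs))


z = 1.3
p_sigmoid = sigmoid(z)
p_softmax = softmax([0.0, z])  # two-class softmax with logits [0, z]

# sigmoid(z) == softmax([0, z])[1], because e^z / (1 + e^z) == 1 / (1 + e^-z)
print(math.isclose(p_sigmoid, p_softmax[1]))

# Same loss for label 0 (one-hot [1, 0]) and label 1 (one-hot [0, 1]).
for y in (0, 1):
    bce = binary_crossentropy(y, p_sigmoid)
    cce = categorical_crossentropy(one_hot(y, 2), p_softmax)
    print(y, math.isclose(bce, cce))

# With more classes, softmax still yields a probability distribution.
print(sum(softmax([2.0, 0.5, -1.0])))
```

This is why the one-unit sigmoid head and the two-unit softmax head are interchangeable for two classes; the sigmoid version simply needs one output unit fewer.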