MultiLabelBinarizer throws an error at prediction time

Posted: 2019-04-23 05:09:00

Tags: python-3.x conv-neural-network multilabel-classification

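My dataset has 6 target labels, so this is a multi-label classification problem. I built a CNN for the classification and trained the embeddings on the corpus. I run into a problem when I use MultiLabelBinarizer to turn the model's predictions back into labels.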
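For context, a minimal sketch of how a `MultiLabelBinarizer` round trip normally behaves. The label names here are hypothetical placeholders, not the real six labels: `fit_transform` turns label sets into a 0/1 indicator matrix, and `inverse_transform` only accepts such a 0/1 matrix back.

```python
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer

# Hypothetical label sets for four documents (six distinct labels in total).
y_raw = [("a", "c"), ("b",), ("a", "d", "f"), ("e",)]

mlb = MultiLabelBinarizer()
Y = mlb.fit_transform(y_raw)  # shape (4, 6), entries are 0 or 1

print(mlb.classes_)               # classes are stored in sorted order
print(Y)
print(mlb.inverse_transform(Y))   # 0/1 rows map back to label tuples

# Anything other than 0s and 1s is rejected, which is the error in this question.
try:
    mlb.inverse_transform(np.array([[0.2, 0.9, 0.1, 0.0, 0.0, 0.0]]))
except ValueError as err:
    print("ValueError:", err)
```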

Model architecture

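from keras.layers import Input, Embedding, Dropout, Conv1D, GlobalMaxPool1D, Dense
from keras.models import Model

# word_index, nb_classes and the upper-case hyperparameters come from the data-loading step.
# +1 because Keras Tokenizer indices start at 1 (0 is reserved for padding).
MAX_VOCAB_SIZE = len(word_index) + 1
embedding_layer = Embedding(MAX_VOCAB_SIZE,
                            EMBED_SIZE,
                            input_length=MAX_SEQUENCE_LENGTH)

seq_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32')
embedded_seq = embedding_layer(seq_input)
x_1 = Dropout(DROP_RATE_EMBEDDING)(embedded_seq)
x_1 = Conv1D(filters=FILTER_LENGTH,
             name='1DCNN_1',
             kernel_size=KERNEL_SIZE,
             padding='valid',
             activation='relu',
             strides=STRIDE)(x_1)
x_1 = GlobalMaxPool1D()(x_1)
# One sigmoid unit per label: each output is an independent probability in (0, 1)
preds = Dense(len(nb_classes), activation='sigmoid')(x_1)

model = Model(inputs=seq_input, outputs=preds)
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['categorical_accuracy'])
model.summary()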
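Because the output layer is `Dense(..., activation='sigmoid')` trained with `binary_crossentropy`, each of the six outputs is an independent probability rather than a 0/1 indicator, and the outputs do not have to sum to 1. A small NumPy sketch of that behaviour; the logits are made-up values, and `binary_crossentropy` here is a hand-written analogue of the per-sample Keras loss, not the library function:

```python
import numpy as np

def sigmoid(z):
    # Element-wise logistic function: maps any real number into (0, 1)
    return 1.0 / (1.0 + np.exp(-z))

def binary_crossentropy(y_true, y_prob, eps=1e-7):
    # Clip to avoid log(0), then average the per-label losses for one sample
    p = np.clip(y_prob, eps, 1 - eps)
    return float(np.mean(-(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))))

logits = np.array([-2.0, 0.0, 3.0])  # hypothetical pre-activation outputs
probs = sigmoid(logits)              # independent probabilities, no sum-to-one constraint
y_true = np.array([0, 1, 1])

print(probs, probs.sum())
print(binary_crossentropy(y_true, probs))
```

So `model.predict` is expected to return floats like the ones printed below, never exact 0s and 1s.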


Here is the prediction code:

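import os
import joblib  # older scikit-learn also exposed this as sklearn.externals.joblib
from keras.preprocessing import text
from keras.preprocessing.sequence import pad_sequences

# load_model() appears to be the author's own helper (it prints "Loaded model from disk");
# nlp is a loaded spaCy pipeline; the upper-case constants come from training.
model = load_model()
test_sentence_token = nlp.tokenizer(test_sentence)  # spaCy tokenizer
# Keep the token texts, dropping stop words
test_sentence_token = [token.text for token in test_sentence_token if not token.is_stop]

# A new Tokenizer is created here, and the list of tokens is passed to texts_to_sequences
tokenizer = text.Tokenizer(num_words=MAX_FEATURES, lower=True)
test_sentence_seq = tokenizer.texts_to_sequences(test_sentence_token)
test_sentence_pad = pad_sequences(test_sentence_seq, maxlen=MAX_SEQUENCE_LENGTH)
prediction = model.predict(test_sentence_pad)  # sigmoid probabilities, one row per input sequence
print(prediction)

multilabel_binarizer = joblib.load(os.path.join(M_PATH, MULTI_LABEL_BINARIZER_FILE))
multilabel_binarizer.inverse_transform(prediction)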
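Two details of this preprocessing are worth checking. `texts_to_sequences` returns one sequence per input text, so passing the list of tokens yields one row per token instead of one row for the sentence. Also, a freshly created `Tokenizer` that was never fitted has an empty `word_index`, so every sequence comes back empty and is padded to all zeros, which would likely explain the nearly identical prediction rows further down. The sketch below uses a hypothetical pure-Python stand-in (`texts_to_sequences_standin`, with a made-up `word_index` and token list) to show the shape difference without requiring Keras:

```python
def texts_to_sequences_standin(texts, word_index):
    # Hypothetical stand-in for Keras Tokenizer.texts_to_sequences:
    # one output sequence per input text; words missing from word_index are dropped.
    return [[word_index[w] for w in t.lower().split() if w in word_index] for t in texts]

# Made-up vocabulary and stop-word-filtered tokens, for illustration only
word_index = {"glue": 1, "adhesion": 2, "degradation": 3, "measured": 4}
tokens = ["addition", "glue", "adhesion", "degradation", "measured"]

per_token = texts_to_sequences_standin(tokens, word_index)           # 5 "texts" -> 5 rows
whole = texts_to_sequences_standin([" ".join(tokens)], word_index)   # 1 text -> 1 row
unfitted = texts_to_sequences_standin(tokens, {})                    # empty vocabulary

print(len(per_token), per_token)
print(len(whole), whole)
print(unfitted)  # every sequence is empty, so padding would give all-zero rows
```

Under these assumptions, the Keras-side change would be to reuse the tokenizer fitted during training (persisted, for example, with joblib just like the binarizer) and call `tokenizer.texts_to_sequences([" ".join(test_sentence_token)])`, so the sentence becomes a single sequence.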

When I pass a single record from X_test, the prediction below is printed and then I get this error:

[[0.0188026  0.29032567 0.02003733 0.0379594  0.5441595  0.26558512]]

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-29-825be263164a> in <module>
      6 multilabel_binarizer = joblib.load(os.path.join(M_PATH,MULTI_LABEL_BINARIZER_FILE))
----> 7 multilabel_binarizer.inverse_transform(prediction)

~/anaconda3/envs/pp/lib/python3.6/site-packages/sklearn/preprocessing/label.py in inverse_transform(self, yt)
    969             if len(unexpected) > 0:
    970                 raise ValueError('Expected only 0s and 1s in label indicator. '
--> 971                                  'Also got {0}'.format(unexpected))
    972             return [tuple(self.classes_.compress(indicators)) for indicators
    973                     in yt]

ValueError: Expected only 0s and 1s in label indicator. Also got [0.0188026  0.02003733 0.0379594  0.26558512 0.29032567 0.5441595 ]
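The error comes from passing probabilities straight to `inverse_transform`, which only accepts a 0/1 indicator matrix. One common approach is to binarize the sigmoid outputs with a threshold first. A sketch using the prediction printed above; the class names are hypothetical placeholders for the pickled binarizer, and the 0.5 threshold is an assumption that would normally be tuned on validation data:

```python
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer

# Hypothetical stand-in for the pickled binarizer: six placeholder class names.
labels = [f"label_{i}" for i in range(6)]
mlb = MultiLabelBinarizer().fit([labels])

# The prediction printed for the X_test record
prediction = np.array([[0.0188026, 0.29032567, 0.02003733,
                        0.0379594, 0.5441595, 0.26558512]])

threshold = 0.5  # assumption: a common default, not a tuned value
binary = (prediction >= threshold).astype(int)  # 1 where the probability clears the threshold

print(binary)
print(mlb.inverse_transform(binary))
```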

I pickled my MultiLabelBinarizer and load it back. When I load everything and predict on a sentence, I get the following error:

test_sentence = 'in addition glue adhesion and its degradation was also measured'
Loaded model from disk
[[0.04990998 0.03565711 0.21524188 0.16965532 0.338592   0.47556564]
 [0.04990998 0.03565711 0.21524188 0.16965532 0.338592   0.47556564]
 [0.04990998 0.03565711 0.21524188 0.16965532 0.338592   0.47556564]
 [0.04990998 0.03565711 0.21524188 0.16965532 0.338592   0.47556564]
 [0.04990995 0.03565711 0.21524191 0.16965532 0.338592   0.47556564]
 [0.04990995 0.03565711 0.21524192 0.16965534 0.338592   0.47556564]]


---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-27-0cdd1d25b589> in <module>
     22     prediction = model.predict(test_sentence_pad)
     23     print(prediction)
---> 24     multilabel_binarizer.inverse_transform(prediction)
     25 else:
     26     X_train, X_val, y_train, y_val, nb_classes, word_index = load_data(df)

~/anaconda3/envs/pp/lib/python3.6/site-packages/sklearn/preprocessing/label.py in inverse_transform(self, yt)
    969             if len(unexpected) > 0:
    970                 raise ValueError('Expected only 0s and 1s in label indicator. '
--> 971                                  'Also got {0}'.format(unexpected))
    972             return [tuple(self.classes_.compress(indicators)) for indicators
    973                     in yt]

ValueError: Expected only 0s and 1s in label indicator. Also got [0.03565711 0.04990995 0.04990998 0.16965532 0.16965534 0.21524188
 0.21524191 0.21524192 0.338592   0.47556564]
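In the sentence example, no probability reaches 0.5 (the largest is 0.47556564), so plain thresholding would return an empty label tuple for every row. If each document is expected to carry at least one label, one option is to fall back to the highest-scoring class. The helper below (`to_indicator`, a hypothetical name) sketches that; the fallback is a design choice, not something `MultiLabelBinarizer` does, and its output can then be passed to `inverse_transform`:

```python
import numpy as np

def to_indicator(probs, threshold=0.5, ensure_one=True):
    """Binarize sigmoid outputs row by row.

    If ensure_one is True, any row with no probability >= threshold gets its
    single highest-scoring label set to 1 (an assumed policy, not sklearn behaviour).
    """
    probs = np.atleast_2d(probs)
    out = (probs >= threshold).astype(int)
    if ensure_one:
        rows = np.flatnonzero(out.sum(axis=1) == 0)  # rows with no label selected
        out[rows, probs[rows].argmax(axis=1)] = 1
    return out

# One row of the sentence prediction printed above
row = np.array([0.04990998, 0.03565711, 0.21524188, 0.16965532, 0.338592, 0.47556564])

print(to_indicator(row, ensure_one=False))
print(to_indicator(row))
```

This only changes how probabilities are turned into labels; getting one prediction row per sentence still depends on feeding the model a single, correctly tokenized sequence.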

0 answers