如何使用sklearn的SGDClassifier获得前N个预测

时间:2018-10-08 09:05:10

标签: python scikit-learn multilabel-classification

我尝试使用scikit的SGDClassifier设置一个简单的文本分类任务,并尝试重新获得前N个预测(包括概率)。作为样本训练数据,我有三个课程

  • 苹果
  • 柠檬
  • 橙色

每堂课只有一个文档:

  • 在苹果中:“苹果和柠檬”
  • 在柠檬中:“柠檬和橙子”
  • 橘子:“橙和苹果”

我现在想预测三个测试文档“ apple”,“ lemon”和“ orange”,并希望获得每个文档的前2个预测,包括它们的能力。到目前为止,我的代码如下:

from sklearn.linear_model import SGDClassifier
from sklearn.datasets import load_files
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.pipeline import Pipeline
import numpy as np

train = load_files('data/test/')

text_clf_svm = Pipeline([('vect', CountVectorizer()), ('tfidf', TfidfTransformer()),
                     ('clf-svm', SGDClassifier(loss='modified_huber', penalty='l2',alpha=1e-3, n_iter=5, random_state=42))])
text_clf_svm = text_clf_svm.fit(train.data, train.target)

docs=['apple', 'orange', 'lemon']
predicted = text_clf_svm.predict(docs)
#Perform a Top 1 prediction
for doc, category in zip(docs, predicted):
    print('%r => %s' % (doc, train.target_names[category]))

# Perform a top 2 prediction
print(np.argsort(text_clf_svm.predict_proba(docs), axis=1)[-2:])

我的输出如下:

'apple' => apples
'orange' => lemons
'lemon' => lemons
[[1 2 0]
[0 1 2]]

我现在很难解释数据。我真正想出去的是:

'apple' => apples (0.54...), lemons (0.43...)
'orange' => apples (0.48...), oranges (0.43...)
'lemon' => lemons (0.48...), oranges (0.43...)

有人可以告诉我我该怎么做吗?预先感谢您的帮助!

2 个答案:

答案 0 :(得分:1)

您正在使用argsort,argsort所做的是它为您提供了已排序数组的索引,因此您应该执行以下操作:

List<Shop> values = get using api;
getShopBox().put(values);


getShopBox().getAll();// does not work after updating

只需将打印重新格式化为您的样式,您将拥有想要的东西:)

答案 1 :(得分:0)

@Imtinan答案的快速补充,因为该答案将您的标签排名为第二高,然后第一高(升序)。相反,如果您希望按降序排列,只需修改:

preds_idx = np.argsort(-preds, axis = 1)[ :2]