从predict_proba中找出课程?

时间:2016-02-01 15:01:19

标签: python scikit-learn

我使用scikit-learn中的predict_proba来查找文档被分配到特定主题的概率。我没有打印出前1个主题并将输入文档X分配给该主题Y,而是有兴趣打印出前5个概率来验证分类是否一致。但是,如何在88个主题中找出这些前5个概率属于哪些主题。

以下是代码和输出:

model = LogisticRegression()
model = model.fit(matrix_tmp, label_tmp)

y_train_pred = model1.predict_log_proba(matrix_tmp_test)
order=np.argsort(y_train_pred, axis=1)
print(order[:, -5:])

所以这会打印出如下矩阵:

[[38 11  6 66  0]
 [20 13 11  0  1]
 [61 11  0 13  1]
 ..., 
 [19 30 13  0  1]
 [13 34 75  0  1]
 [ 0 46  3  1 40]]

根据排序,0表示概率最高的主题,66表示最高的主题,依此类推。我的问题是如何在矩阵中找出这些数字的主题。总共有88个主题(根据 model.classes _ 标记为0到87,其中只考虑前5个。那么如何以类似的方式打印主题?< / p>

1 个答案:

答案 0 :(得分:0)

你几乎拥有它。我认为以下简单的解决方案应该有效(我很快就测试了它,它至少对我有用):

print(model.classes_[order[:, -5:]])