Sklearn预测多个输出

时间:2016-11-23 17:19:11

标签: python python-3.x scikit-learn

我写了以下代码:

from sklearn import tree

# Dataset & labels
# Using metric units
# features = [height, weight, style]
styles = ['modern', 'classic']
features = [[1.65, 65, 1], 
            [1.55, 50, 1],
            [1.76, 64, 0],
            [1.68, 77, 0] ]
labels = ['Yellow dress', 'Red dress', 'Blue dress', 'Green dress']

# Decision Tree
clf = tree.DecisionTreeClassifier()
clf = clf.fit(features, labels)

# Returns the dress
height = input('Height: ')
weight = input('Weight: ')
style = input('Modern [0] or Classic [1]: ')
print(clf.predict([[height,weight,style]]))

此代码接收用户的身高和体重,然后返回更适合她的着装。有没有办法返回多个选项?例如,返回两件或更多件连衣裙。

更新

from sklearn import tree
import numpy as np

# Dataset & labels
# features = [height, weight, style]
# styles = ['modern', 'classic']
features = [[1.65, 65, 1], 
            [1.55, 50, 1],
            [1.76, 64, 1],
            [1.72, 68, 0],
            [1.73, 68, 0],
            [1.68, 77, 0]]
labels =    ['Yellow dress',
            'Red dress',
            'Blue dress',
            'Green dress',
            'Purple dress',
            'Orange dress']

# Decision Tree
clf = tree.DecisionTreeClassifier()
clf = clf.fit(features, labels)

# Returns the dress
height = input('Height: ')
weight = input('Weight: ')
style = input('Modern [0] or Classic [1]: ')

print(clf.predict_proba([[height,weight,style]]))

如果用户是1.72米和68公斤,我想要同时显示绿色和紫色礼服。这个例子只返回100%的绿色礼服。

3 个答案:

答案 0 :(得分:5)

是的,你可以。实际上你可以做的是你可以获得每个班级的概率。有一个名为.predict_proba()的函数在某些分类器中实现。

请参阅here,sklearn的文档。

它将返回每个班级样本的成员概率。

然后,您可以返回与两个,三个最高概率相关联的标签。

答案 1 :(得分:2)

您可以使用predict_proba()方法获取给定项目的每个类的概率。有关更多信息,请查看" predict_proba()":

http://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html

希望这会有所帮助..

答案 2 :(得分:2)

predict()将仅返回概率较高的类。如果你使用 predict_proba()相反,它将返回一个数组,其中包含每个类的概率,因此您可以选择高于某个阈值的数组。

Here是该方法的文档。

你可以用它做这样的事情:

probs = clf.predict_proba([[height, weight, style]])
threshold = 0.25 # change this accordingly
for index, prob in enumerate(probs[0]):
    if prob > threshold:
        print(styles[index])