使用SVM预测概率

时间:2018-03-27 07:42:59

标签: python classification svm libsvm

我编写了这段代码,希望获得分类概率。

from sklearn import svm
X = [[0, 0], [10, 10],[20,30],[30,30],[40, 30], [80,60], [80,50]]
y = [0, 1, 2, 3, 4, 5, 6]
clf = svm.SVC() 
clf.probability=True
clf.fit(X, y)
prob = clf.predict_proba([[10, 10]])
print prob

我获得了这个输出:

[[0.15376986 0.07691205 0.15388546 0.15389275 0.15386348 0.15383004 0.15384636]]

这很奇怪,因为概率应该是

[0 1 0 0 0 0 0 0]

(观察到必须预测哪个类的样本与第二个样本相同),该类获得的概率最低。

3 个答案:

答案 0 :(得分:3)

编辑:正如@TimH所指出的,概率可以由clf.decision_function(X)给出。以下代码是固定的。使用predict_proba(X)注意到低概率的指定问题,我认为答案是根据官方文档here ....而且,它会在非常小的数据集上产生无意义的结果。< / em>的

理解SVM的最终可能性的答案。 简而言之,您在2D平面上有7个等级和7个点。 SVM正在尝试做的是在每个类和每个类之间找到一个线性分隔符(一对一方法)。每次只选择2个班级。 你得到的是归一化后分类器的投票。请参阅this帖子或here中有关 libsvm 的多类SVM的详细说明(scikit-learn使用libsvm)。

通过略微修改代码,我们看到确实选择了正确的类:

from sklearn import svm
import matplotlib.pyplot as plt
import numpy as np


X = [[0, 0], [10, 10],[20,30],[30,30],[40, 30], [80,60], [80,50]]
y = [0, 1, 2, 3, 4, 5, 6]
clf = svm.SVC() 
clf.fit(X, y)

x_pred = [[10,10]]
x_pred = [[10,10]]
p = np.array(clf.decision_function(X)) # decision is a voting function
prob = np.exp(p)/np.sum(np.exp(p),axis=1) # softmax after the voting
classes = clf.predict(X)

_ = [print('Sample={}, Prediction={},\n Votes={} \nP={}, '.format(idx,c,v, s)) for idx, (v,s,c) in enumerate(zip(p,prob,classes))]

相应的输出是

Sample=0, Prediction=0,
Votes=[ 6.5         4.91666667  3.91666667  2.91666667  1.91666667  0.91666667 -0.08333333] 
P=[ 0.75531071  0.15505748  0.05704246  0.02098475  0.00771986  0.00283998  0.00104477], 
Sample=1, Prediction=1,
Votes=[ 4.91666667  6.5         3.91666667  2.91666667  1.91666667  0.91666667 -0.08333333] 
P=[ 0.15505748  0.75531071  0.05704246  0.02098475  0.00771986  0.00283998  0.00104477], 
Sample=2, Prediction=2,
Votes=[ 1.91666667  2.91666667  6.5         4.91666667  3.91666667  0.91666667 -0.08333333] 
P=[ 0.00771986  0.02098475  0.75531071  0.15505748  0.05704246  0.00283998  0.00104477], 
Sample=3, Prediction=3,
Votes=[ 1.91666667  2.91666667  4.91666667  6.5         3.91666667  0.91666667 -0.08333333] 
P=[ 0.00771986  0.02098475  0.15505748  0.75531071  0.05704246  0.00283998  0.00104477], 
Sample=4, Prediction=4,
Votes=[ 1.91666667  2.91666667  3.91666667  4.91666667  6.5         0.91666667 -0.08333333] 
P=[ 0.00771986  0.02098475  0.05704246  0.15505748  0.75531071  0.00283998  0.00104477], 
Sample=5, Prediction=5,
Votes=[ 3.91666667  2.91666667  1.91666667  0.91666667 -0.08333333  6.5  4.91666667] 
P=[ 0.05704246  0.02098475  0.00771986  0.00283998  0.00104477  0.75531071  0.15505748], 
Sample=6, Prediction=6,
Votes=[ 3.91666667  2.91666667  1.91666667  0.91666667 -0.08333333  4.91666667  6.5       ] 
P=[ 0.05704246  0.02098475  0.00771986  0.00283998  0.00104477  0.15505748  0.75531071], 

您还可以看到决策区:

X = np.array(X)
y = np.array(y)
fig = plt.figure(figsize=(8,8))
ax = fig.add_subplot(111)

XX, YY = np.mgrid[0:100:200j, 0:100:200j]
Z = clf.predict(np.c_[XX.ravel(), YY.ravel()])

Z = Z.reshape(XX.shape)
plt.figure(1, figsize=(4, 3))
plt.pcolormesh(XX, YY, Z, cmap=plt.cm.Paired)

for idx in range(7):
    ax.scatter(X[idx,0],X[idx,1], color='k')

enter image description here

答案 1 :(得分:2)

您应该停用probability并改为使用decision_function,因为无法保证predict_probapredict返回相同的结果。 您可以在documentation

中详细了解相关信息
clf.predict([[10, 10]]) // returns 1 as expected 

prop = clf.decision_function([[10, 10]]) // returns [[ 4.91666667  6.5         3.91666667  2.91666667  1.91666667  0.91666667
      -0.08333333]]
prediction = np.argmax(prop) // returns 1 

答案 2 :(得分:1)

你可以read in the docs ......

  

SVC方法decision_function给出每个样本的每个类别得分(或二进制情况下每个样本的单个得分)。当构造函数选项概率设置为True时,将启用类成员概率估计(来自方法predict_proba和predict_log_proba)。在二元情形中,概率使用Platt缩放校准:对SVM分数的逻辑回归,通过对训练数据的额外交叉验证拟合。在多类情况下,根据Wu等人的说法进行了扩展。 (2004)。

     

毋庸置疑, Platt缩放中涉及的交叉验证对于大型数据集来说是一项昂贵的操作此外,在得分的“argmax”可能不是概率的argmax的意义上,概率估计可能与得分不一致。 (例如,在二进制分类中,样本可以通过预测标记为属于具有概率&lt;½的类根据predict_proba 。) Platt的方法也已知具有理论问题。 如果需要置信度,但这些不一定是概率,那么建议设置probability = False并使用decision_function而不是predict_proba。

Stack Overflow用户对此功能也存在很多困惑,您可以在this threadthis one中看到。