我正在将OneVsRestClassifier用于svm.SVC作为基本估计量的多类问题。
predict_proba()
中的argmax与预测的类不匹配:
后台是否正在进行某种标准化?我如何获得predict_proba()和predict() 匹配?
答案 0 :(得分:0)
根据scikit learn's SVC documentation on multi-class classification,predict
的输出和argmax
的{{1}}的输出之间存在差异(强调我):
SVC和NuSVC的Decision_function方法为每个样本提供每个类别的分数(或在二进制情况下为每个样本单个分数)。当构造函数选项概率设置为True时,将启用类成员资格概率估计(来自predict_proba和predict_log_proba方法)。在二进制情况下,概率使用Platt缩放比例进行校准:对SVM得分进行逻辑回归,并通过对训练数据进行额外的交叉验证来拟合。在多类情况下,这根据Wu等人进行了扩展。 (2004)。
毋庸置疑,Platt缩放所涉及的交叉验证对于大型数据集而言是一项昂贵的操作。 此外,在分数的“ argmax”可能不是概率的argmax的意义上,概率估计可能与分数不一致。 (例如,在二进制分类中,根据predict_proba,样本可能被预测标记为属于概率<1/2的一类。)普氏方法也存在理论问题。如果需要置信度分数,但不一定非要是概率,那么建议将概率设置为False,并使用Decision_function而不是predict_proba。
您无法使用SVC使它们匹配。如果需要概率,可以尝试其他模型。如文档中所述,如果不需要概率,则可以使用predict_proba
(有关更多详细信息,请参见here。)