了解Decision_Function值

时间:2018-06-21 22:36:16

标签: python machine-learning scikit-learn

我目前正处于第一次机器学习的中期,到目前为止,我还没有完全了解从decision_function(X)获得的值的规模(也无法理解它们)。

基于sklearn documentation decision_function(X)旨在:

  

预测样品的置信度得分。

尽管如此,在运行以下脚本时:

from sklearn.datasets import fetch_mldata
import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix , precision_score, recall_score


mnist = fetch_mldata('MNIST original')

classifier = SGDClassifier(random_state = 42, max_iter = 5)


X,y = mnist["data"], mnist["target"]
some_digit = X[36001]
some_digit_image = some_digit.reshape(28, 28)

X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

random_order = np.random.permutation(60000)

X_train, y_train = X_train[random_order], y_train[random_order]

y_test_5 = (y_test == 5)
y_train_5 = (y_train == 5)


classifier.fit(X_train, y_train_5)
print(classifier.decision_function([X_test[1]]))

它现在为[-289809.39489525]打印decision_function,我不确定如何阅读或评估这些值(我希望看到百分比) 。如果有人可以向我解释这些读数的含义,将不胜感激。

非常感谢您。

1 个答案:

答案 0 :(得分:3)

如何获取概率(百分比)?

使用predict_proba方法。

什么是 decision_function

由于SGDClassifier是线性模型,因此decision_function向分离的超平面输出一个有符号的距离。此数字只是<< strong> w , x > + b或转换为scikit-learn属性名称 <coef_ x > + intercept_