How to plot a sigmoid probability curve in scikit-learn?

Posted: 2016-04-29 21:46:58

Tags: python machine-learning scikit-learn

I am trying to recreate this image in Python, given two classes and the predicted probabilities that a classifier assigns to them.

I'd like to see something like this: [Figure: sigmoid curve]

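However, it doesn't work: what I get is mostly a straight line. **Note: I know the data shown is currently questionable and/or poor. I still need to tune the model's inputs and outputs, but I wanted to see the plot first.**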

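Basically, I thought I would "correct" the predict_proba() outputs so that they all refer to class "0": if the classifier predicts class "1", the probability that the record is class "0" is 1 - P(class "1"), so a 95% prediction of class "1" becomes a 5% chance of class "0". I then sort by these corrected predicted values, expecting to end up with something sigmoid-ish.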
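One plausible reason the sorted-probability plot comes out linear (an assumption, since the original data are not shown): with poorly separated classes, the linear predictor z = theta0 + theta . x stays in a narrow band around 0, where the logistic function is almost a straight line. The numpy-only sketch below uses made-up z values (hypothetical, drawn uniformly) and measures how straight the sorted probabilities are against their rank. Using 1 - P(class "1") instead only mirrors the curve; it does not change its shape.

```python
import numpy as np


def sigmoid(z):
    """Logistic function 1 / (1 + exp(-z))."""
    return 1.0 / (1.0 + np.exp(-z))


def rank_linearity(z):
    """Correlation between sorted sigmoid(z) and its rank (1.0 = perfectly straight line)."""
    p_sorted = np.sort(sigmoid(z))
    ranks = np.arange(len(p_sorted))
    return np.corrcoef(ranks, p_sorted)[0, 1]


rng = np.random.default_rng(0)
# Hypothetical linear-predictor values, NOT the question's data:
z_narrow = rng.uniform(-1.0, 1.0, size=500)  # poorly separated classes
z_wide = rng.uniform(-6.0, 6.0, size=500)    # well separated classes

print(f"narrow z range: {rank_linearity(z_narrow):.4f}")
print(f"wide z range:   {rank_linearity(z_wide):.4f}")
```

The narrow range gives a correlation very close to 1 (a near-straight line), while the wide range bends into an S. Plotting against z itself, rather than against rank, shows the logistic shape regardless of how the z values are distributed.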

Unfortunately, I ended up with this instead: [image]

Here is the chunk of my Python code with which I am (unsuccessfully) trying to plot the probabilities as a sigmoid:

###########################
 ## I removed my original Python code because it was very, very wrong so as to avoid any confusion.
###########################

For reference, below is the MATLAB code for the plot I am trying to replicate with my Python model.

%Build the model
mdl = fitglm(X, Y, 'distr', 'binomial', 'link', 'logit')
%Build the sigmoid model
B = mdl.Coefficients{:, 1};
Z = mdl.Fitted.LinearPredictor
yhat = glmval(B, X, 'logit'); 
figure, scatter(Z, yhat), hold on,
gscatter(Z, zeros(length(X),1)-0.1, Y) % plot original classes
hold off, xlabel('\bf Z'),  grid on,  ylim([-0.2 1.05])
title('\bf Predicted Probability of each record')
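For readers without MATLAB: `mdl.Fitted.LinearPredictor` is the linear predictor Z = B(1) + X*B(2:end), and `glmval(B, X, 'logit')` applies the inverse logit to it, yhat = 1 / (1 + exp(-Z)). A minimal numpy sketch of those two quantities, assuming B is ordered intercept first (as in fitglm's coefficient table) and X is shaped (n_samples, n_features); the function name is hypothetical:

```python
import numpy as np


def linear_predictor_and_probability(B, X):
    """Return (Z, yhat) for a logistic GLM, mimicking MATLAB's
    mdl.Fitted.LinearPredictor and glmval(B, X, 'logit').

    B[0] is the intercept, B[1:] are the per-feature coefficients.
    X must be 2-D: (n_samples, n_features).
    """
    B = np.asarray(B, dtype=float)
    X = np.asarray(X, dtype=float)
    Z = B[0] + X @ B[1:]             # linear predictor
    yhat = 1.0 / (1.0 + np.exp(-Z))  # inverse logit link
    return Z, yhat


# One feature, intercept 0 and slope 1 (illustrative values only)
Z, yhat = linear_predictor_and_probability([0.0, 1.0], [[-2.0], [0.0], [2.0]])
print(Z)
print(yhat)
```

The MATLAB plotting lines then map onto two scatter calls: Z against yhat, and Z against a constant -0.1 coloured by the true class Y.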

1 answer:

Answer 0 (score: 0)

There may be a more Pythonic way to do this, but this is what I eventually came up with:

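(Keep in mind that the data in this case are not well separated, so the curve does not have the textbook look in which the classes split cleanly at the 0.50 point of the sigmoid.)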

#############################################################################
#### Draws a sigmoid probability plot from prediction results ###############
#############################################################################
# Assumes these already exist from the earlier modelling code: a fitted binary
# LogisticRegression `clf`, X_train, X_test, y_test, and a list `labels`
# holding the target name followed by the feature names.
import matplotlib.pyplot as plt
import numpy as np
print('-' * 40)

# make the predictions (class) and also get the prediction probabilities;
# column 1 of predict_proba is P(class "1")
y_train_predict = clf.predict(X_train)
y_train_predictProbas = clf.predict_proba(X_train)[:, 1]

y_test_predict = clf.predict(X_test)
y_test_predictProbas = clf.predict_proba(X_test)[:, 1]
y_test = np.asarray(y_test)  # positional indexing works even for a pandas Series

# Get the thetas from the model
thetas = clf.coef_[0]
intercept = clf.intercept_[0]
print('thetas=')
print(thetas)
print('intercept=')
print(intercept)

# Display the predictors and their associated thetas
for idx, x in enumerate(thetas):
    print("Predictor: " + str(labels[idx + 1]) + "=" + str(x))

# Prepend the intercept to the thetas (scikit-learn stores theta0 separately in
# intercept_) and prepend a matching column of ones to X_test
interceptAndThetas = np.append([intercept], thetas)
X_testWithThetaZero = np.column_stack([np.ones(len(X_test)), X_test])

# Linear predictor z = theta0 + thetas . x for each test row (the x-axis of
# the sigmoid); this equals clf.decision_function(X_test)
dotProductResult = X_testWithThetaZero.dot(interceptAndThetas)

fig, ax1 = plt.subplots()

# Colour every point by its actual class: green = 1, black = 0
colors = ['green' if cls == 1 else 'black' for cls in y_test]
# plot the predicted probability on the sigmoid curve
ax1.scatter(dotProductResult, y_test_predictProbas, c=colors, linewidths=0.0)
# plot the actual class (0 or 1) at the same z
ax1.scatter(dotProductResult, y_test, c=colors, linewidths=0.0)

# determine which predictions are "wrong" (on the wrong side of 0.5) so we
# can make a histogram of where they fall along z
wrong = np.where(y_test == 1, y_test_predictProbas < 0.5, y_test_predictProbas > 0.5)
wrongDotProducts = dotProductResult[wrong]
rightDotProducts = dotProductResult[~wrong]

ax1.set_ylim([-0.05, 1.05])
ax1.set_xlabel('z = theta0 + thetas . x')
ax1.grid(which='major', axis='both')
ax1.set_ylabel('Prediction Probability')
ax1.set_title('Sigmoid Curve & Histogram of Predictions')

## Add a histogram to show where the correct/incorrect predictions fall
ax2 = ax1.twinx()
bins = np.arange(-7, 8)  # unit-width bins from -7 to 7; z outside this range is not counted
ax2.hist(wrongDotProducts, bins=bins, hatch="/", color="black", alpha=0.2)
ax2.hist(rightDotProducts, bins=bins, hatch="\\", color="green", alpha=0.2)
ax2.set_ylabel('Histogram Count of Predictions\nGreen=Correct Black=Wrong')
plt.show()

[Figure: Sigmoid Curve & Histogram of Predictions]
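A shorter variant of the same idea, sketched under the assumption of a binary LogisticRegression and placeholder data from make_classification (not the original poster's dataset): `clf.decision_function(X)` already returns the linear predictor z = intercept_ + coef_ . x, so the manual append-and-dot-product step can be skipped, and for a binary LogisticRegression `predict_proba(X)[:, 1]` is the logistic function of that z. Overlaying the analytic curve makes the S shape visible even when most points cluster near z = 0.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; remove this line for an interactive window
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Placeholder data: two classes, four features (two informative)
X, y = make_classification(n_samples=300, n_features=4, n_informative=2,
                           n_redundant=0, random_state=0)
clf = LogisticRegression().fit(X, y)

z = clf.decision_function(X)                      # linear predictor per sample
p = clf.predict_proba(X)[:, 1]                    # P(class 1)
z_manual = X @ clf.coef_[0] + clf.intercept_[0]   # the answer's hand-computed z

colors = ["green" if cls == 1 else "black" for cls in y]
fig, ax = plt.subplots()
ax.scatter(z, p, c=colors, s=10, linewidths=0.0)   # predictions on the curve
ax.scatter(z, y, c=colors, s=10, linewidths=0.0)   # actual classes at 0 and 1

# Analytic logistic curve over the observed z range
grid = np.linspace(z.min(), z.max(), 200)
ax.plot(grid, 1.0 / (1.0 + np.exp(-grid)), color="red")
ax.set_xlabel("z (decision_function)")
ax.set_ylabel("Prediction Probability")
ax.set_ylim(-0.05, 1.05)
ax.grid(True)
# plt.show()  # in an interactive session
```

If `z` and `z_manual` disagree, the model is probably multiclass or not a LogisticRegression, and the binary assumptions above no longer hold.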