There is a significant mismatch between the R² scores computed by sklearn.model_selection.permutation_test_score and sklearn.metrics.r2_score. The one computed by permutation_test_score appears to be incorrect; see below:
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from sklearn.model_selection import permutation_test_score

x = np.arange(1, 6, 1)
x = np.reshape(x, (5, 1))
y = np.array([1.9, 3.7, 5.8, 8.0, 9.6])
y = np.reshape(y, (5, 1))

# Fit a line to the data
lin_reg = LinearRegression()
lin_reg.fit(x, y)
print(lin_reg.intercept_, lin_reg.coef_)
# intercept -0.11, slope 1.97

# Compute the prediction values (f) from our fitted line
f = lin_reg.predict(x)
print(f)
# [[ 1.86]
# [ 3.83]
# [ 5.8 ]
# [ 7.77]
# [ 9.74]]

# Calculate R^2 explicitly: 1 - (residual sum of squares / total sum of squares)
yminusf2 = (y - f) ** 2
sserr = sum(yminusf2)
mean = float(sum(y)) / float(len(y))
yminusmean2 = (y - mean) ** 2
sstot = sum(yminusmean2)
R2 = 1. - (sserr / sstot)
print(R2)
# 0.99766067

# Use sklearn.metrics.r2_score
print(r2_score(y, f))
# 0.99766066838
print(r2_score(y, f) == R2)
# [ True]

# Use sklearn.model_selection.permutation_test_score
r2_sc, perm_sc, pval = permutation_test_score(lin_reg, x, y, n_permutations=100, scoring='r2', cv=None)
print(r2_sc)
# 0.621593653548
print(r2_sc == R2)
# [False]
Answer 0 (score: 1)
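Yes, they differ. You are computing the score on the entire data set (that is, fitting on x and then predicting on the same x), so R2 and r2_score() come out very high. permutation_test_score(), on the other hand, does not score on the whole data set: it uses cross-validation and returns the mean of the scores over all folds.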
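Note that permutation_test_score() also has a cv parameter. If it is not specified it is None, which defaults to 3-fold cross-validation (equivalent to KFold(3)), as specified in the documentation. This applies to scikit-learn releases before 0.22; later releases default to 5-fold.
cv : int, cross-validation generator or an iterable, optional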
Determines the cross-validation splitting strategy. Possible inputs for cv are:
- None, to use the default 3-fold cross validation,
- integer, to specify the number of folds in a (Stratified)KFold,
- An object to be used as a cross-validation generator.
- An iterable yielding train, test splits.
So the score returned by permutation_test_score is the mean of all the scores obtained through cross-validation.
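As a rough illustration of how an unshuffled 3-fold split divides the five samples used here, consider the pure-Python sketch below. The helper name `kfold_indices` is made up for this example; it only mimics the sizing rule of an unshuffled `KFold` (the first `n_samples % n_splits` folds get one extra sample) and is not the library's code.

```python
def kfold_indices(n_samples, n_splits):
    """Yield (train, test) index lists for contiguous, unshuffled folds.

    Hypothetical helper for illustration: the first n_samples % n_splits
    folds receive one extra sample, the rest get n_samples // n_splits.
    """
    base, extra = divmod(n_samples, n_splits)
    start = 0
    for i in range(n_splits):
        size = base + (1 if i < extra else 0)
        test = list(range(start, start + size))
        train = [j for j in range(n_samples) if j not in test]
        yield train, test
        start += size


# Five samples, three folds, as in the question
folds = list(kfold_indices(5, 3))
for train, test in folds:
    print("train:", train, "test:", test)
```

With five samples the split cannot be even: the test folds come out with sizes 2, 2 and 1, so the last fold is scored on a single point.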
You can confirm that the returned score is the mean over the folds with cross_val_score, which returns the score for each fold:
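from sklearn.model_selection import cross_val_score

# Per-fold R^2 scores, using the same default cv as permutation_test_score
r2_sc_cv = cross_val_score(lin_reg, x, y, scoring='r2', cv=None)
print(r2_sc_cv)
# array([ 0.91975309, 0.94502787, 0. ])

# The mean over the folds matches the score from permutation_test_score
r2_sc_cv_mean = np.average(r2_sc_cv)
print(r2_sc_cv_mean)
# 0.62159365354781015
print(r2_sc_cv_mean == r2_sc)
# True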
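Look at the last score in the r2_sc_cv array: it is 0.0, and that is why the mean score drops. With five samples split into three unshuffled folds, the last test fold holds a single sample, so its total sum of squares is zero and R² is not well defined there; a score of 0.0 is reported instead.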
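To see where the three fold scores come from, the cross-validation can be reproduced by hand. The following is a minimal pure-Python sketch, not scikit-learn's implementation. It assumes the unshuffled test folds are [0, 1], [2, 3] and [4], and it assumes a convention of returning 0.0 when the total sum of squares is zero, chosen to match the 0.0 seen above. The helper names `fit_line` and `r2` are made up for this example.

```python
# Data from the question
xs = [1.0, 2.0, 3.0, 4.0, 5.0]
ys = [1.9, 3.7, 5.8, 8.0, 9.6]


def fit_line(x, y):
    """Ordinary least squares for one feature; returns (intercept, slope)."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    sxx = sum((xi - mx) ** 2 for xi in x)
    sxy = sum((xi - mx) * (yi - my) for xi, yi in zip(x, y))
    slope = sxy / sxx
    return my - slope * mx, slope


def r2(y_true, y_pred):
    """R^2 = 1 - SS_res / SS_tot.

    Assumed convention: when SS_tot is zero (e.g. a single test sample),
    return 0.0 unless the prediction is exact, in which case return 1.0.
    """
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    if ss_tot == 0.0:
        return 1.0 if ss_res == 0.0 else 0.0
    return 1.0 - ss_res / ss_tot


# Assumed test folds for an unshuffled 3-fold split of five samples
test_folds = [[0, 1], [2, 3], [4]]

scores = []
for test in test_folds:
    train = [i for i in range(len(xs)) if i not in test]
    b0, b1 = fit_line([xs[i] for i in train], [ys[i] for i in train])
    preds = [b0 + b1 * xs[i] for i in test]
    scores.append(r2([ys[i] for i in test], preds))

mean_score = sum(scores) / len(scores)
print(scores)
print(mean_score)
```

The first two folds score about 0.9198 and 0.9450, the single-sample fold scores 0.0, and their mean is about 0.6216, in line with the value returned by permutation_test_score. Supplying a cv with at least two samples in every test fold would avoid the degenerate fold.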