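I wrote the code below, which uses stratified k-fold to split the dataset, fits a multinomial logistic regression (statsmodels MNLogit) on each fold, and then computes the accuracy. My X is an array with 19 variables (the last one is the clustering variable), and y has 3 classes (0, 1, 2).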
import numpy as np
import statsmodels.api as sm
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import StratifiedKFold

X = np.asarray(df[[*all 19 columns here*]], dtype="float64")
y = np.asarray(df["categoric_var"], dtype="int")
acc_test = []
acc_train = []
skf = StratifiedKFold(n_splits=5, shuffle=True)
split_n = 0
for train_ix, test_ix in skf.split(X, y):
    split_n += 1
    X_train, X_valid = X[train_ix], X[test_ix]
    y_train, y_valid = y[train_ix], y[test_ix]
    # the last column holds the cluster labels used for the robust covariance
    cluster_groups = X_train[:, -1]
    X_train2 = X_train[:, :-1].astype("float64")  # remove clustering variable
    X_valid2 = X_valid[:, :-1].astype("float64")  # remove clustering variable
    mnl = sm.MNLogit(y_train, X_train2).fit(cov_type="cluster", cov_kwds={"groups": cluster_groups})
    print(mnl.summary())
    train_pred = mnl.predict(X_train2)
    # turn predicted probabilities into final classification, into a list
    pred_list_train = []
    for row in train_pred:
        if np.where(row == np.amax(row))[0] == 0:
            pred_list_train.append(0)
        elif np.where(row == np.amax(row))[0] == 1:
            pred_list_train.append(1)
        else:
            pred_list_train.append(2)
    print('MNLogit Regression, training set, fold ', split_n, ': ', classification_report(y_train, pred_list_train))
    pred = mnl.predict(X_valid2)
    # turn predicted probabilities into final classification, into a list
    pred_list_test = []
    for row in pred:
        if np.where(row == np.amax(row))[0] == 0:
            pred_list_test.append(0)
        elif np.where(row == np.amax(row))[0] == 1:
            pred_list_test.append(1)
        else:
            pred_list_test.append(2)
    # Measure of the fit of the model
    print('MNLogit Regression, validation set, fold ', split_n, ': ', classification_report(y_valid, pred_list_test))
    acc_test.append(accuracy_score(y_valid, pred_list_test))
    acc_train.append(accuracy_score(y_train, pred_list_train))
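As an aside on the probability-to-class step: when a row of predicted probabilities is all nan, `row == np.amax(row)` is False everywhere, so neither `if` branch matches and the row falls through to the `else` and is labelled 2. A minimal sketch of a more compact conversion that makes such rows visible is shown here; the helper name `probs_to_labels` and the `-1` sentinel are my own choices for illustration, not part of statsmodels or scikit-learn.

```python
import numpy as np


def probs_to_labels(probs, failed_label=-1):
    """Return the index of the largest probability in each row.

    Rows that contain NaN get `failed_label` (a hypothetical sentinel)
    instead of silently ending up in the last class.
    """
    probs = np.asarray(probs, dtype="float64")
    labels = np.full(probs.shape[0], failed_label, dtype=int)
    finite_rows = ~np.isnan(probs).any(axis=1)
    labels[finite_rows] = np.argmax(probs[finite_rows], axis=1)
    return labels


# Made-up probabilities: two valid rows and one row from a failed fit.
example = np.array([
    [0.1, 0.7, 0.2],
    [np.nan, np.nan, np.nan],
    [0.5, 0.2, 0.3],
])
labels = probs_to_labels(example)
print(labels)
```

With this version a fold whose fit failed shows up as `-1` predictions rather than as perfect recall for class 2.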
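The problem is that I have two versions of y: one in which the classes are more imbalanced (version 1) and one in which they are more balanced (version 2).
When I run this code on version 1 of y, it works perfectly. However, when I run it on version 2, some folds return a regression in which every value is nan. Here is an example (apologies for the length). These are the results for the first two folds: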
C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\discrete\discrete_model.py:2251: RuntimeWarning: divide by zero encountered in log
logprob = np.log(self.cdf(np.dot(self.exog,params)))
C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\discrete\discrete_model.py:2252: RuntimeWarning: invalid value encountered in multiply
return np.sum(d * logprob)
Optimization terminated successfully.
Current function value: nan
Iterations 14
C:\ProgramData\Anaconda3\lib\site-packages\scipy\stats\_distn_infrastructure.py:903: RuntimeWarning: invalid value encountered in greater
return (a < x) & (x < b)
C:\ProgramData\Anaconda3\lib\site-packages\scipy\stats\_distn_infrastructure.py:903: RuntimeWarning: invalid value encountered in less
return (a < x) & (x < b)
C:\ProgramData\Anaconda3\lib\site-packages\scipy\stats\_distn_infrastructure.py:1912: RuntimeWarning: invalid value encountered in less_equal
cond2 = cond0 & (x <= _a)
MNLogit Regression Results
==============================================================================
Dep. Variable: y No. Observations: 13852
Model: MNLogit Df Residuals: 13814
Method: MLE Df Model: 36
Date: Thu, 13 Aug 2020 Pseudo R-squ.: nan
Time: 23:04:09 Log-Likelihood: nan
converged: True LL-Null: -13943.
Covariance Type: cluster LLR p-value: nan
==============================================================================
y=1 coef std err z P>|z| [0.025 0.975]
------------------------------------------------------------------------------
x1 -0.0012 0.009 -0.126 0.900 -0.020 0.017
x2 0.0001 1.8e-05 6.207 0.000 7.63e-05 0.000
x3 -0.6074 0.621 -0.978 0.328 -1.825 0.610
x4 8.5373 1.219 7.004 0.000 6.148 10.926
x5 0.0136 0.002 5.906 0.000 0.009 0.018
x6 0.0024 0.066 0.037 0.970 -0.127 0.131
x7 -0.0060 0.003 -1.972 0.049 -0.012 -3.76e-05
x8 -0.0263 0.015 -1.695 0.090 -0.057 0.004
x9 -0.0237 0.026 -0.926 0.355 -0.074 0.026
x10 -0.0008 0.002 -0.404 0.686 -0.005 0.003
x11 0.0713 0.031 2.308 0.021 0.011 0.132
x12 -9.272e-05 1.54e-05 -6.003 0.000 -0.000 -6.24e-05
x13 -0.0012 0.000 -4.696 0.000 -0.002 -0.001
x14 5.53e-05 1.06e-05 5.215 0.000 3.45e-05 7.61e-05
x15 -0.0007 0.000 -3.538 0.000 -0.001 -0.000
x16 7.334e-05 6.94e-05 1.056 0.291 -6.27e-05 0.000
x17 -0.0098 0.001 -9.659 0.000 -0.012 -0.008
x18 -0.0506 0.036 -1.409 0.159 -0.121 0.020
x19 0.0953 0.017 5.682 0.000 0.062 0.128
------------------------------------------------------------------------------
y=2 coef std err z P>|z| [0.025 0.975]
------------------------------------------------------------------------------
x1 0.0354 0.025 1.411 0.158 -0.014 0.084
x2 0.0003 0.000 1.996 0.046 5.62e-06 0.001
x3 3.3663 3.177 1.060 0.289 -2.860 9.593
x4 16.6473 8.483 1.962 0.050 0.021 33.273
x5 0.0507 0.026 1.963 0.050 7.82e-05 0.101
x6 0.3423 0.278 1.232 0.218 -0.202 0.887
x7 0.0274 0.026 1.051 0.293 -0.024 0.079
x8 0.0998 0.071 1.397 0.162 -0.040 0.240
x9 -0.0231 0.049 -0.466 0.641 -0.120 0.074
x10 0.0126 0.006 1.969 0.049 5.65e-05 0.025
x11 0.2219 0.129 1.720 0.085 -0.031 0.475
x12 -0.0002 8.6e-05 -2.286 0.022 -0.000 -2.8e-05
x13 -0.0022 0.001 -2.591 0.010 -0.004 -0.001
x14 0.0001 5.35e-05 2.313 0.021 1.89e-05 0.000
x15 -0.0018 0.001 -2.209 0.027 -0.003 -0.000
x16 6.439e-05 0.000 0.468 0.640 -0.000 0.000
x17 -0.8636 0.047 -18.523 0.000 -0.955 -0.772
x18 1.7166 4.104 0.418 0.676 -6.328 9.761
x19 0.0713 0.052 1.375 0.169 -0.030 0.173
==============================================================================
MNLogit Regression, training set, fold 21 : precision recall f1-score support
0 0.89 0.78 0.83 3679
1 0.76 0.83 0.80 2738
2 0.97 1.00 0.98 7435
accuracy 0.91 13852
macro avg 0.87 0.87 0.87 13852
weighted avg 0.91 0.91 0.90 13852
MNLogit Regression, validation set, fold 21 : precision recall f1-score support
0 0.88 0.78 0.83 920
1 0.77 0.82 0.79 685
2 0.97 1.00 0.98 1859
accuracy 0.90 3464
macro avg 0.87 0.86 0.87 3464
weighted avg 0.90 0.90 0.90 3464
shape xtrain: (13853, 19)
shape ytrain: (13853,)
C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\discrete\discrete_model.py:2219: RuntimeWarning: overflow encountered in exp
eXB = np.column_stack((np.ones(len(X)), np.exp(X)))
C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\discrete\discrete_model.py:2220: RuntimeWarning: invalid value encountered in true_divide
return eXB/eXB.sum(1)[:,None]
C:\ProgramData\Anaconda3\lib\site-packages\statsmodels\base\optimizer.py:300: RuntimeWarning: invalid value encountered in greater
oldparams) > tol)):
Optimization terminated successfully.
Current function value: nan
Iterations 6
MNLogit Regression Results
==============================================================================
Dep. Variable: y No. Observations: 13853
Model: MNLogit Df Residuals: 13815
Method: MLE Df Model: 36
Date: Thu, 13 Aug 2020 Pseudo R-squ.: nan
Time: 23:04:10 Log-Likelihood: nan
converged: True LL-Null: -13944.
Covariance Type: cluster LLR p-value: nan
==============================================================================
y=1 coef std err z P>|z| [0.025 0.975]
------------------------------------------------------------------------------
x1 nan nan nan nan nan nan
x2 nan nan nan nan nan nan
x3 nan nan nan nan nan nan
x4 nan nan nan nan nan nan
x5 nan nan nan nan nan nan
x6 nan nan nan nan nan nan
x7 nan nan nan nan nan nan
x8 nan nan nan nan nan nan
x9 nan nan nan nan nan nan
x10 nan nan nan nan nan nan
x11 nan nan nan nan nan nan
x12 nan nan nan nan nan nan
x13 nan nan nan nan nan nan
x14 nan nan nan nan nan nan
x15 nan nan nan nan nan nan
x16 nan nan nan nan nan nan
x17 nan nan nan nan nan nan
x18 nan nan nan nan nan nan
x19 nan nan nan nan nan nan
------------------------------------------------------------------------------
y=2 coef std err z P>|z| [0.025 0.975]
------------------------------------------------------------------------------
x1 nan nan nan nan nan nan
x2 nan nan nan nan nan nan
x3 nan nan nan nan nan nan
x4 nan nan nan nan nan nan
x5 nan nan nan nan nan nan
x6 nan nan nan nan nan nan
x7 nan nan nan nan nan nan
x8 nan nan nan nan nan nan
x9 nan nan nan nan nan nan
x10 nan nan nan nan nan nan
x11 nan nan nan nan nan nan
x12 nan nan nan nan nan nan
x13 nan nan nan nan nan nan
x14 nan nan nan nan nan nan
x15 nan nan nan nan nan nan
x16 nan nan nan nan nan nan
x17 nan nan nan nan nan nan
x18 nan nan nan nan nan nan
x19 nan nan nan nan nan nan
==============================================================================
__main__:42: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False, but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.
__main__:44: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False, but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.
C:\ProgramData\Anaconda3\lib\site-packages\sklearn\metrics\_classification.py:1272: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, msg_start, len(result))
__main__:54: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False, but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.
__main__:56: DeprecationWarning: The truth value of an empty array is ambiguous. Returning False, but in future this will result in an error. Use `array.size > 0` to check that an array is not empty.
C:\ProgramData\Anaconda3\lib\site-packages\sklearn\metrics\_classification.py:1272: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, msg_start, len(result))
MNLogit Regression, training set, fold 21 : precision recall f1-score support
0 0.00 0.00 0.00 3679
1 0.00 0.00 0.00 2739
2 0.54 1.00 0.70 7435
accuracy 0.54 13853
macro avg 0.18 0.33 0.23 13853
weighted avg 0.29 0.54 0.37 13853
MNLogit Regression, validation set, fold 21 : precision recall f1-score support
0 0.00 0.00 0.00 920
1 0.00 0.00 0.00 684
2 0.54 1.00 0.70 1859
accuracy 0.54 3463
macro avg 0.18 0.33 0.23 3463
weighted avg 0.29 0.54 0.38 3463
I don't know what is going on here, because nothing has really changed except the values of the dependent variable.
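For what it's worth, the `overflow encountered in exp` warning from the second fold can be reproduced in isolation. The sketch below copies the two lines quoted in that warning (`eXB = np.column_stack(...)` and `return eXB/eXB.sum(1)[:,None]`) into a stand-alone function and feeds it a very large linear predictor. The names `mnlogit_cdf` and `standardize_with_train_stats` and all input values are made up for illustration, and it is only an assumption on my part that columns on very different scales are what pushes the linear predictor this high in my data.

```python
import numpy as np


def mnlogit_cdf(XB):
    """Stand-alone copy of the two statsmodels lines quoted in the warnings."""
    eXB = np.column_stack((np.ones(len(XB)), np.exp(XB)))
    return eXB / eXB.sum(1)[:, None]


def standardize_with_train_stats(X_train, X_valid):
    """Hypothetical helper: z-score both folds using training-fold statistics only."""
    mean = X_train.mean(axis=0)
    std = X_train.std(axis=0)
    constant = std == 0
    mean[constant] = 0.0  # leave constant columns (e.g. an intercept) unchanged
    std[constant] = 1.0
    return (X_train - mean) / std, (X_valid - mean) / std


# Made-up linear predictors: one moderate row and one very large row.
with np.errstate(over="ignore", invalid="ignore"):
    probs_ok = mnlogit_cdf(np.array([[1.0, 2.0]]))
    probs_bad = mnlogit_cdf(np.array([[800.0, 2.0]]))

print(np.isnan(probs_ok).any())   # moderate values stay finite
print(np.isnan(probs_bad).any())  # exp(800) overflows to inf, and inf/inf is nan

# Made-up features on very different scales, rescaled per fold.
rng = np.random.default_rng(0)
scales = np.array([1.0, 1e3, 1e6])
X_tr = rng.normal(size=(100, 3)) * scales
X_va = rng.normal(size=(20, 3)) * scales
X_tr_s, X_va_s = standardize_with_train_stats(X_tr, X_va)
print(X_tr_s.std(axis=0))
```

Once a single probability is nan, the log-likelihood and every parameter update become nan as well, which would match the all-nan summary. I have not verified whether rescaling the features removes the nan folds in my case.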