带有sklearn LogisticRegression和RandomForest模型的Predict()始终预测少数群体类别(1)

时间:2018-08-22 14:03:53

标签: python machine-learning scikit-learn logistic-regression

我正在建立一个Logistic回归模型,以仅150个观测值的数据集来预测交易是否有效(1)(0)。我的数据在两个类之间的分配情况如下:

  • 106个观察值为0(无效)
  • 44个观察值为1(有效)

我正在使用两个预测变量(均为数值)。尽管数据大多为0,但我的分类器只为我的测试集中的每笔交易预测1,即使大多数交易应为0。分类器从不为任何观察输出0。

这是我的完整代码:

# Logistic Regression
import numpy as np
import pandas as pd
from pandas import Series, DataFrame

import scipy
from scipy.stats import spearmanr
from pylab import rcParams
import seaborn as sb
import matplotlib.pyplot as plt
import sklearn
from sklearn.preprocessing import scale
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn import metrics
from sklearn import preprocessing

address = "dummy_csv-150.csv"
trades = pd.read_csv(address)
trades.columns=['location','app','el','rp','rule1','rule2','rule3','validity','transactions']
trades.head()

trade_data = trades.ix[:,(1,8)].values
trade_data_names = ['app','transactions']

# set dependent/response variable
y = trades.ix[:,7].values

# center around the data mean
X= scale(trade_data)

LogReg = LogisticRegression()

LogReg.fit(X,y)
print(LogReg.score(X,y))

y_pred = LogReg.predict(X)

from sklearn.metrics import classification_report

print(classification_report(y,y_pred)) 

log_prediction = LogReg.predict_log_proba(
    [
       [2, 14],[3,1], [1, 503],[1, 122],[1, 101],[1, 610],[1, 2120],[3, 85],[3, 91],[2, 167],[2, 553],[2, 144]
    ])
prediction = LogReg.predict([[2, 14],[3,1], [1, 503],[1, 122],[1, 101],[1, 610],[1, 2120],[3, 85],[3, 91],[2, 167],[2, 553],[2, 144]])

我的模型定义为:

LogReg = LogisticRegression()  
LogReg.fit(X,y)

其中X看起来像这样:

X = array([[1, 345],
       [1, 222],
       [1, 500],
       [2, 120]]....)

每个观察值的Y仅为0或1。

传递给模型的

Normalized X是这样的:

[[-1.67177659  0.14396503]
 [-1.67177659 -0.14538932]
 [-1.67177659  0.50859856]
 [-1.67177659 -0.3853417 ]
 [-1.67177659 -0.43239119]
 [-1.67177659  0.743846  ]
 [-1.67177659  4.32195953]
 [ 0.95657805 -0.46062089]
 [ 0.95657805 -0.45591594]
 [ 0.95657805 -0.37828428]
 [ 0.95657805 -0.52884264]
 [ 0.95657805 -0.20420118]
 [ 0.95657805 -0.63705646]
 [ 0.95657805 -0.65587626]
 [ 0.95657805 -0.66763863]
 [-0.35759927 -0.25125067]
 [-0.35759927  0.60975496]
 [-0.35759927 -0.33358727]
 [-0.35759927 -0.20420118]
 [-0.35759927  1.37195666]
 [-0.35759927  0.27805607]
 [-0.35759927  0.09456307]
 [-0.35759927  0.03810368]
 [-0.35759927 -0.41121892]
 [-0.35759927 -0.64411389]
 [-0.35759927 -0.69586832]
 [ 0.95657805 -0.57353966]
 [ 0.95657805 -0.57353966]
 [ 0.95657805 -0.53825254]
 [ 0.95657805 -0.53354759]
 [ 0.95657805 -0.52413769]
 [ 0.95657805 -0.57589213]
 [ 0.95657805  0.03810368]
 [ 0.95657805 -0.66293368]
 [ 0.95657805  2.86107294]
 [-1.67177659  0.14396503]
 [-1.67177659 -0.14538932]
 [-1.67177659  0.50859856]
 [-1.67177659 -0.3853417 ]
 [-1.67177659 -0.43239119]
 [-1.67177659  0.743846  ]
 [-1.67177659  4.32195953]
 [ 0.95657805 -0.46062089]
 [ 0.95657805 -0.45591594]
 [ 0.95657805 -0.37828428]
 [ 0.95657805 -0.52884264]
 [ 0.95657805 -0.20420118]
 [ 0.95657805 -0.63705646]
 [ 0.95657805 -0.65587626]
 [ 0.95657805 -0.66763863]
 [-0.35759927 -0.25125067]
 [-0.35759927  0.60975496]
 [-0.35759927 -0.33358727]
 [-0.35759927 -0.20420118]
 [-0.35759927  1.37195666]
 [-0.35759927  0.27805607]
 [-0.35759927  0.09456307]
 [-0.35759927  0.03810368]
 [-0.35759927 -0.41121892]
 [-0.35759927 -0.64411389]
 [-0.35759927 -0.69586832]
 [ 0.95657805 -0.57353966]
 [ 0.95657805 -0.57353966]
 [ 0.95657805 -0.53825254]
 [ 0.95657805 -0.53354759]
 [ 0.95657805 -0.52413769]
 [ 0.95657805 -0.57589213]
 [ 0.95657805  0.03810368]
 [ 0.95657805 -0.66293368]
 [ 0.95657805  2.86107294]
 [-1.67177659  0.14396503]
 [-1.67177659 -0.14538932]
 [-1.67177659  0.50859856]
 [-1.67177659 -0.3853417 ]
 [-1.67177659 -0.43239119]
 [-1.67177659  0.743846  ]
 [-1.67177659  4.32195953]
 [ 0.95657805 -0.46062089]
 [ 0.95657805 -0.45591594]
 [ 0.95657805 -0.37828428]
 [ 0.95657805 -0.52884264]
 [ 0.95657805 -0.20420118]
 [ 0.95657805 -0.63705646]
 [ 0.95657805 -0.65587626]
 [ 0.95657805 -0.66763863]
 [-0.35759927 -0.25125067]
 [-0.35759927  0.60975496]
 [-0.35759927 -0.33358727]
 [-0.35759927 -0.20420118]
 [-0.35759927  1.37195666]
 [-0.35759927  0.27805607]
 [-0.35759927  0.09456307]
 [-0.35759927  0.03810368]
 [-0.35759927 -0.41121892]
 [-0.35759927 -0.64411389]
 [-0.35759927 -0.69586832]
 [ 0.95657805 -0.57353966]
 [ 0.95657805 -0.57353966]
 [ 0.95657805 -0.53825254]
 [ 0.95657805 -0.53354759]
 [ 0.95657805 -0.52413769]
 [ 0.95657805 -0.57589213]
 [ 0.95657805  0.03810368]
 [ 0.95657805 -0.66293368]
 [ 0.95657805  2.86107294]
 [-1.67177659  0.14396503]
 [-1.67177659 -0.14538932]
 [-1.67177659  0.50859856]
 [-1.67177659 -0.3853417 ]
 [-1.67177659 -0.43239119]
 [-1.67177659  0.743846  ]
 [-1.67177659  4.32195953]
 [ 0.95657805 -0.46062089]
 [ 0.95657805 -0.45591594]
 [ 0.95657805 -0.37828428]
 [ 0.95657805 -0.52884264]
 [ 0.95657805 -0.20420118]
 [ 0.95657805 -0.63705646]
 [ 0.95657805 -0.65587626]
 [ 0.95657805 -0.66763863]
 [-0.35759927 -0.25125067]
 [-0.35759927  0.60975496]
 [-0.35759927 -0.33358727]
 [-0.35759927 -0.20420118]
 [-0.35759927  1.37195666]
 [-0.35759927  0.27805607]
 [-0.35759927  0.09456307]
 [-0.35759927  0.03810368]
 [-0.35759927 -0.41121892]
 [-0.35759927 -0.64411389]
 [-0.35759927 -0.69586832]
 [ 0.95657805 -0.57353966]
 [ 0.95657805 -0.57353966]
 [ 0.95657805 -0.53825254]
 [ 0.95657805 -0.53354759]
 [ 0.95657805 -0.52413769]
 [ 0.95657805 -0.57589213]
 [ 0.95657805  0.03810368]
 [ 0.95657805 -0.66293368]
 [ 0.95657805  2.86107294]
 [-0.35759927  0.60975496]
 [-0.35759927 -0.33358727]
 [-0.35759927 -0.20420118]
 [-0.35759927  1.37195666]
 [-0.35759927  0.27805607]
 [-0.35759927  0.09456307]
 [-0.35759927  0.03810368]]

Y是:

[0 0 0 0 0 0 1 1 0 0 0 1 1 1 1 0 0 0 0 1 0 0 0 0 1 1 0 0 0 0 0 0 1 1 1 0 0
 0 0 0 0 1 1 0 0 0 1 1 1 1 0 0 0 0 1 0 0 0 0 1 1 0 0 0 0 0 0 1 1 1 0 0 0 0
 0 0 1 1 0 0 0 1 1 1 1 0 0 0 0 1 0 0 0 0 1 1 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0
 1 1 0 0 0 1 1 1 1 0 0 0 0 1 0 0 0 0 1 1 0 0 0 0 0 0 1 1 1 0 0 0 1 0 0 0]

模型指标是:

             precision    recall  f1-score   support

          0       0.78      1.00      0.88        98
          1       1.00      0.43      0.60        49

avg / total       0.85      0.81      0.78       147

得分为 0.80

当我运行model.predict_log_proba(test_data)时,我得到如下所示的概率区间:

array([[ -1.10164032e+01,  -1.64301095e-05],
       [ -2.06326947e+00,  -1.35863187e-01],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00],
       [            -inf,   0.00000000e+00]])

我的测试集为,除2之外的所有其他项都应为0,但它们都归为1。对于每个测试集都会发生这种情况,即使是那些具有训练模型价值的测试集。

[2, 14],[3,1], [1, 503],[1, 122],[1, 101],[1, 610],[1, 2120],[3, 85],[3, 91],[2, 167],[2, 553],[2, 144]

我在这里找到了一个类似的问题:https://stats.stackexchange.com/questions/168929/logistic-regression-is-predicting-all-1-and-no-0,但是在这个问题中,问题似乎在于数据主要是1,因此模型可以输出1s是有意义的。我的情况则相反,因为火车数据大部分为0,但由于某种原因,即使1相对较少,我的模型也始终为所有输出1。我还尝试了一个随机森林分类器,以查看模型是否错误,但是发生了同样的事情。也许这是我的数据,但我不知道它有什么问题,因为它符合所有假设。

可能是什么问题?数据满足逻辑模型的所有假设(两个预测变量都是独立的,输出是二进制的,没有丢失的数据点)。任何建议表示赞赏。

1 个答案:

答案 0 :(得分:1)

您没有扩展test数据。您这样做时正确地调整了火车数据:

X= scale(trade_data)

训练模型后,您不会对测试数据进行相同的操作

log_prediction = LogReg.predict_log_proba(
[
   [2, 14],[3,1], [1, 503],[1, 122],[1, 101],[1, 610],[1, 2120],[3, 85],[3, 91],[2, 167],[2, 553],[2, 144]
])

您的模型系数是建立在期望归一化输入的基础上的。您的测试数据未标准化。该模型的任何正系数都将乘以一个巨大的数字,因为您的数据无法缩放,可能会使您的预测值全部变为1。

一般规则是,您对训练集所做的任何转换也应同样在测​​试集上进行。您还应该对训练集和测试集应用相同的变换。代替:

X = scale(trade_data)

您应该像这样从训练数据中创建缩放器:

scaler = StandardScaler().fit(trade_date)
X = scaler.transform(trade_data)

然后稍后将该缩放器应用于您的test数据:

scaled_test = scaler.transform(test_x)