How to convert the output of the predict() method of any classifier in sklearn?

Asked: 2018-01-15 14:10:56

Tags: python machine-learning scikit-learn classification linear-regression

I have started learning how to apply different types of classifiers in python's sklearn module. clf_LR.predict(X_predict) predicts 'Loan_Status' for the test data. In the training data, this column is 1 or 0 depending on whether the loan was approved. But the prediction gives an array of float values around 0 and 1. I want to convert these values to the nearest 1 or 0.

    # regression
    import numpy as np
    from sklearn.linear_model import LinearRegression

    # features: every column except the 'Loan_Status' target
    X = np.array(train_data.drop(['Loan_Status'], axis=1))
    y = np.array(train_data['Loan_Status'])
    X_predict = np.array(test_data)

    clf_LR = LinearRegression()
    clf_LR.fit(X, y)
    accuracy = clf_LR.score(X, y)  # R^2 on the training data, not classification accuracy
    clf_LR.predict(X_predict)

The output is:

array([ 1.0531505 ,  0.54463698,  0.66512836,  0.91817899,  0.81084038,
    0.4400971 ,  0.05132584,  0.5797642 ,  0.72760712,  0.78624   ,
    0.60043618,  0.79904144,  0.78164806,  0.63140686,  0.66746683,
    0.56799806,  0.62462483, -0.27487531,  0.77595855,  0.62112923,
    0.42499627,  0.21962665,  0.73747749,  0.62580336,  1.08242647,
    0.60546731,  0.58980138,  0.68778534,  0.80729382, -0.25906255,
    0.5911749 ,  0.57754607,  0.71869494,  0.7414411 ,  0.79574657,
    1.053294  ,  0.77238618,  0.84663303,  0.93977499,  0.39076889,
    0.79835196, -0.31202102,  0.57969628,  0.6782184 ,  0.62406822,
    0.76141175, -0.14311827,  0.87284553,  0.45152395,  0.70505136,
    0.80529711,  0.88614397,  0.0036123 ,  0.59748637,  1.15082822,
    0.6804735 ,  0.64551666, -0.28882904,  0.71713245,  0.66373934,
    0.5250008 ,  0.81825485,  0.71661801,  0.74462875,  0.66047019,
    0.62186449, -0.2895147 ,  0.78990148, -0.198547  ,  0.02752572,
    1.0440052 ,  0.58668459,  0.82012492,  0.50745345, -0.07448848,
    0.56636204,  0.85462188,  0.4723699 ,  0.5501792 ,  0.91271145,
    0.61796331,  0.47130567,  0.74644572,  0.38340698,  0.65640869,
    0.75736077, -0.23866258,  0.89198235,  0.74552824,  0.58952803,
    0.75363266,  0.44341609,  0.76332621,  0.60706656,  0.548128  ,
   -0.05460422,  0.81488009,  0.51959111,  0.91001994,  0.71223763,
    0.67600868,  0.79102218, -0.00530356,  0.20135057,  0.73923083,
    0.56965262,  0.80045725,  0.67266281,  0.81694555,  0.70263141,
    0.38996739,  0.38449832,  0.77388573,  0.92362979,  0.54006616,
    0.76432229,  0.61683807,  0.44803386,  0.79751796,  0.55321023,
    1.10480386,  1.03004599,  0.54718652,  0.74741632,  0.83907984,
    0.86407637,  1.10821273,  0.6227142 ,  0.94443767, -0.02906777,
    0.68258672,  0.38914101,  0.86936186, -0.17331518,  0.35980983,
   -0.32387964,  0.86583445,  0.5480951 ,  0.5846661 ,  0.96815188,
    0.45474766,  0.54342586,  0.41997578,  0.73069535,  0.05828308,
    0.4716423 ,  0.70579418,  0.76672804,  0.90476146,  0.45363533,
    0.78646442,  0.76841914,  0.77227952,  0.75068078,  0.94713967,
    0.67417191, -0.16948404,  0.80726176,  1.12127705,  0.74715634,
    0.44632464,  0.61668874,  0.6578295 ,  0.60631521,  0.42455094,
    0.65104766, -0.01636441,  0.87456921, -0.24877682,  0.76791838,
    0.85037569,  0.75076961,  0.91323444,  0.27976108,  0.89643734,
    0.14388116,  0.7340059 ,  0.46372024,  0.91726212,  0.43539411,
    0.44859789, -0.04401285,  0.28901989,  0.62105238,  0.56949422,
    0.49728522,  0.65641239,  1.11183953,  0.76159204,  0.55822867,
    0.79752582,  0.72726221,  0.49171728, -0.32777583, -0.30767082,
    0.70702693,  0.91792405,  0.76112155,  0.68748705,  0.6172974 ,
    0.70335159,  0.74522648,  1.01560133,  0.62808723,  0.50816819,
    0.61760714,  0.55879101,  0.50060645,  0.87832261,  0.73523273,
    0.60360986,  0.78153534, -0.2063286 ,  0.85540569,  0.59231311,
    0.75875401,  0.34422049,  0.58667666, -0.14887532,  0.81458285,
    0.90631338,  0.5508966 ,  0.93534451,  0.0048111 ,  0.66506743,
    0.5844512 ,  0.67768398,  0.91190474,  0.39758323,  0.44284897,
    0.47347625,  0.7603246 ,  0.41066447,  0.50419741,  0.74437409,
    0.44916515,  0.14160128,  0.72991652,  1.15215444,  0.50707437,
    0.61020873,  0.8831041 ,  0.78476914,  0.4953215 ,  0.71862044,
    0.66574986,  0.89547805,  0.93534669,  0.57742771,  0.9225718 ,
    0.67209865,  0.34461023,  0.52848926,  0.95846303,  0.88237609,
   -0.01603499,  0.94158916,  0.44069838, -0.17133448,  0.35288583,
    0.55302018,  0.36446662,  0.62047864,  0.3803367 ,  0.60398751,
    0.9152663 ,  0.48237299,  0.05646119, -0.65950771,  0.52644392,
   -0.14182158,  0.65408783, -0.01741803,  0.76022561,  0.70883902,
    0.56782191,  0.66484671,  0.79638622,  0.6668274 ,  0.94365746,
    0.76132423,  0.63407964,  0.43784118,  0.74599199,  0.69594847,
    0.96794245,  0.49120557, -0.30985337,  0.48242465,  0.78788   ,
    0.74562549,  0.61188416, -0.13990599,  0.59192289,  0.52577439,
    0.62118612,  0.47292839,  0.38433912,  0.58535049,  0.61180443,
    0.68363366, -0.17158279, -0.16752298, -0.12006642,  0.11420194,
    0.54435597,  0.76707794,  0.94712879,  0.90341355,  0.41133755,
    0.78063296,  1.06335948,  0.65061658,  0.55463919, -0.16184664,
    0.45612831,  0.2974657 ,  0.74769718,  0.73568274,  0.91792405,
    0.69938454,  0.07815941,  0.73400855,  0.33905491,  0.48330823,
    0.76760269, -0.03303408,  0.64432907,  0.44763337,  0.59214243,
    0.78339532,  0.74755724,  0.70328769,  0.61766433, -0.34196805,
    0.74271219,  0.66617484,  0.75939014,  0.46274977,  0.43760914,
   -0.11568388,  1.12101126,  0.65718951,  0.74632966, -0.3918828 ,
    0.29915035,  0.6155425 ,  0.66089274,  0.8555285 ,  0.54121081,
    0.74758901,  0.84686185,  0.68150433,  0.44953323,  0.71672738,
    0.86416735,  0.97374945,  0.36594854,  0.5508358 ,  0.60524084,
   -0.04479449,  0.56064679,  0.46826815,  0.75353414,  0.63092004,
    0.52340796,  0.36622527,  0.42553235,  0.81877722, -0.03474048,
    0.56185539,  0.57384744,  0.86959987, -0.35002778,  0.59209448,
    0.43892519,  0.83366299,  0.55630127,  0.68092981,  0.79639642,
    0.96289854, -0.15094804,  0.5866888 ,  0.88245453,  0.65447514,
    1.00194182,  0.45130259, -0.16774169,  0.66529484,  0.87330175,
    0.12493249,  0.07427334,  0.79084776,  0.60848656,  0.7706963 ,
    0.76846985,  0.74796571,  0.52316893,  0.62116966,  0.52497383,
    0.05855483,  0.75575428, -0.20233853,  0.77693886,  0.15845594,
    0.88457158,  0.0846857 ,  0.7831948 ,  0.54955829,  0.71151434,
    1.23277406,  0.0153455 ,  0.7111069 ,  0.64140878,  0.69578766,
    0.72386089,  0.3291767 ,  0.8414526 , -0.14267676,  0.93841726,
    0.94248916,  0.61492774,  0.60835432, -0.05542942,  1.01387972,
    0.81980896,  0.39519755,  0.85483256,  0.79124875,  0.46196837,
    0.5157149 , -0.2076404 ,  0.57935033,  0.86477299,  0.62917312,
    0.85446301,  0.40595525,  0.64527099,  0.7452028 ,  0.58527638,
    0.66419528,  0.49120555,  0.83966651,  0.86063059,  0.85615707,
   -0.22704174])

I want to convert these values to the nearest 1 or 0. Is there a way to do this?

2 Answers:

Answer 0 (Score: 1):

    import numpy as np
    np.round(np.clip(clf_LR.predict(X_predict), 0, 1))                # 0.0 / 1.0 floats
    np.round(np.clip(clf_LR.predict(X_predict), 0, 1)).astype(bool)   # True / False booleans

Technically, this code does not operate in-place, but it can be made to (using the out parameter)!

(Untested: give it a try!)
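
For completeness, a minimal sketch of that in-place variant, assuming this is what the answer means by the out parameter (untested against the asker's data):

    import numpy as np

    preds = clf_LR.predict(X_predict)
    np.clip(preds, 0, 1, out=preds)   # clamp values below 0 and above 1, writing back into preds
    np.round(preds, out=preds)        # round to the nearest 0.0 / 1.0, again in place

Both np.clip and np.round accept an out array, so no extra copy of the predictions is allocated.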

Answer 1 (Score: 0):

As @Pault said, what you need is a classifier, and sklearn has plenty of classifiers! Which classifier to use depends on many factors:

The flowchart from the sklearn documentation ("Choosing the right estimator") can help you choose one.

Basically, for a logistic regression classifier, you can do the following:

    from sklearn.linear_model import LogisticRegression

    clf = LogisticRegression(C=1.0, penalty='l1')  # note: newer sklearn versions require solver='liblinear' or 'saga' for penalty='l1'
    clf.fit(X, y)
    clf.predict(X_predict)  # will give you 0 or 1 as the class
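
As a small addition beyond the original answer (assuming the same fitted clf as above): LogisticRegression also exposes predict_proba, which returns class probabilities instead of hard 0/1 labels, in case you later want to apply your own decision threshold:

    proba = clf.predict_proba(X_predict)        # shape (n_samples, 2): columns are P(class 0), P(class 1)
    labels = (proba[:, 1] >= 0.5).astype(int)   # thresholding at 0.5 reproduces clf.predict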