DecisionTreeClassifier predict_proba returns 0 or 1

Date: 2018-01-12 05:16:55

Tags: python decision-tree roc sklearn-pandas

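I am trying to use a classification decision tree to distinguish two classes (renamed to 0 and 1) based on certain parameters. I train it on one dataset and then run it on a test dataset. When I try to compute the probability for each data point in the test dataset, it returns only 0 or 1. I would like to know what the problem is.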

Here is the sample code:

from sklearn import tree

clf = tree.DecisionTreeClassifier(random_state=0)
trained = clf.fit(data, identifier)  # training data, where identifier is 0 or 1
predict = trained.predict(test_data)

The results are as follows:

In [9]: predict

Out[9]: 
array([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0,
       1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0,
       0, 0, 1, 0, 0, 0])

In [10]: trained.predict_proba(test_data)[:,1]

Out[10]: 
array([ 0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  1.,  1.,
        0.,  0.,  1.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  1.,  1.,  0.,
        1.,  0.,  1.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  1.,  1.,  1.,
        0.,  1.,  1.,  0.,  0.,  1.,  0.,  0.,  0.,  1.,  0.,  0.,  0.])

I want to generate a ROC curve, and at the moment I get only 3 data points for FPR / TPR.

Here is the full dataset. The identifier is the last column of "data" (the Class column).

Training data:

Spectral_Index,W1-W2,W2-W3,HR0.3-100,HR50-2,Gamma,Class
1.4304664,0.61,2.18,0.3819051,0.99992716,1.93,0
1.6969398,0.54,1.93,0.66479063,0.9999814,2.11,0
2.233997,1.02,3.18,0.55532146,0.9999979,2.07,0
2.230639,0.77,2.34,0.0012237767,1.0,1.81,0
1.7325432,0.71,2.27,0.34395835,1.0,1.9,0
1.8728518,0.8,2.14,0.4255796,1.0,1.96,0
1.9818852,0.7,2.18,-0.08978904,1.0,1.66,0
2.3864453,0.95,2.51,0.109010585,0.98401743,1.81,0
2.5911317,0.94,2.49,0.60381645,0.99991965,2.03,0
1.9564596,0.81,2.29,0.3843,0.9999495,2.08,0
2.1506176,0.93,2.62,0.28551856,0.9999999,1.91,0
1.9069784,0.62,1.76,0.041608978,1.0,1.86,0
1.6216202,0.77,2.11,-0.14271076,1.0,1.7,0
2.276335,0.68,2.14,0.40399882,1.0,2.06,0
2.2430172,1.0,2.94,0.61844856,1.0,2.12,0
1.0226197,0.66,2.07,-0.14886126,1.0,1.84,0
2.2564504,1.06,2.77,0.6974536,0.99844635,2.16,0
2.2819016,0.88,2.37,0.30696234,0.999996,1.86,0
1.4881139,0.7,2.09,0.40853307,1.0,1.82,0
2.4640048,0.9,2.39,0.35103577,1.0,2.02,0
2.656071,0.72,2.29,0.21568911,0.9999046,2.11,0
1.7204628,0.62,2.01,0.19794853,1.0,1.8,0
1.9134961,0.86,2.27,0.37281907,1.0,1.94,0
1.3061943,0.67,2.01,0.3463318,0.99999976,1.86,0
1.8845558,0.64,2.01,0.12364135,0.9999834,1.84,0
2.4409518,1.12,3.31,0.7502838,1.0,2.17,0
1.9501582,0.85,2.34,0.29961613,0.9999974,1.92,0
2.1314192,1.03,2.62,0.69623667,1.0,2.28,0
1.7345899,0.69,2.61,0.38524705,0.99999887,2.09,0
1.7095753,0.75,2.08,0.21696341,0.9999987,1.95,0
1.9115254,0.83,2.17,-0.046689913,1.0,1.85,0
1.565369,0.67,2.01,-0.04827315,0.9999915,1.79,0
2.2971635,0.59,2.1,0.35741857,1.0,2.0,0
3.042759,1.06,2.94,0.70878696,0.9999844,2.15,0
2.340724,0.96,2.74,0.42822766,0.99999416,1.97,0
1.8552977,0.74,2.09,0.07262661,1.0,1.69,0
2.0324602,0.66,2.05,-0.07643526,0.9999982,1.83,0
1.8508979,0.67,1.96,0.054557554,0.99997455,1.75,0
2.7983437,0.96,2.58,0.8554537,0.9999992,2.2,0
2.1728642,1.09,3.05,0.61488354,1.0,2.04,0
3.113785,0.66,1.85,0.48011553,0.99995273,1.95,0
3.0665417,0.78,2.19,0.27814054,1.0,1.86,0
2.0060341,0.83,2.39,0.20785762,0.9999502,1.85,0
2.1786506,0.57,2.0,0.33096096,1.0,1.91,0
1.823961,0.72,1.96,-0.103285044,1.0,1.6,0
1.612012,0.68,2.15,-0.3136376,0.65517294,1.52,0
2.1615896,0.87,2.4,0.47535577,1.0,2.04,0
2.3053634,1.06,2.92,0.67040676,0.9991328,2.15,0
1.7525402,0.73,2.12,0.25563625,0.9999979,1.92,0
2.7306526,0.91,2.35,0.68943393,-0.4308276,2.1,0
2.2549937,1.07,2.91,0.6077795,0.9999626,2.04,0
2.0924683,0.69,2.04,-0.068183094,0.3497915,1.77,0
2.210627,0.84,2.09,0.6309954,0.99999976,1.99,0
2.4609168,0.67,2.08,0.29552716,0.99964327,1.96,0
2.5169518,0.84,2.45,0.35437247,0.9999745,1.92,0
2.1841373,0.9,2.51,0.5617463,1.0,2.15,0
3.0673068,0.8,2.22,0.17641401,1.0,1.9,0
2.6202004,0.97,2.47,0.36663872,1.0,2.03,0
1.9694642,0.95,2.54,0.33140072,0.99998665,2.04,0
1.8766946,0.84,2.32,-0.024992371,0.99999803,1.94,0
2.9352057,1.2,2.96,0.6385377,0.9951195,2.18,0
1.4075257,0.86,2.27,0.046303034,0.9999998,1.81,0
1.8769667,0.6,2.0,0.08842805,0.15410244,1.83,0
1.2585826,0.71,1.96,0.005930161,0.78259146,1.72,0
2.2046561,0.9,2.37,0.62021697,1.0,2.07,0
1.0217602,0.49,1.89,-0.26944694,0.9999997,1.66,0
2.1021683,1.05,2.78,0.5306551,1.0,2.14,0
2.4789429,0.94,2.52,0.34224525,0.9999965,2.01,0
2.1449182,0.8,2.32,0.37609425,0.9997282,2.25,0
2.7071185,0.83,2.36,0.75363404,1.0,2.31,0
1.8445525,1.04,2.76,0.6075378,0.88632137,2.14,0
1.6024263,1.09,2.63,0.64461184,1.0,2.18,0
2.0292685,0.53,2.15,0.090091705,1.0,1.92,0
2.0858748,0.71,1.86,0.14351326,0.9999994,1.88,0
2.1292083,0.81,2.31,0.33257455,1.0,1.95,0
1.6344122,0.84,2.38,0.6371139,0.9999998,2.11,0
1.7532507,0.75,2.04,0.16182575,1.0,1.78,0
2.2479355,0.97,2.72,0.41953298,1.0,2.04,0
2.5790315,1.07,2.96,0.7216893,0.9999953,2.11,0
3.0039942,1.03,2.44,0.8042694,0.9998856,2.25,1
3.7599833,1.16,3.23,0.9095345,0.66683024,2.39,1
2.8912013,1.05,2.67,0.85215354,0.9967052,2.27,1
3.8784094,1.11,3.18,0.6971026,1.0,2.19,1
2.1862392,1.13,2.7,0.65855825,1.0,2.28,1
2.7684402,1.16,2.79,0.9261603,-0.9540385,2.35,1
1.7551649,0.56,2.18,0.23092282,1.0,1.98,1
2.804592,1.13,2.98,0.84827685,1.0,2.3,1
1.9874831,1.0,2.98,0.87599415,1.0,2.21,1
2.5059428,1.16,2.79,0.97649753,0.9997586,2.42,1
2.812127,1.12,3.11,0.87392867,1.0,2.21,1
2.9445121,1.06,3.17,0.8849491,1.0,2.41,1
2.7388847,1.11,2.78,0.84986275,0.96669436,2.32,1
2.1416433,1.1,3.61,0.7671358,0.9999998,2.29,1
2.3661094,1.05,3.16,0.73194104,0.99990827,2.14,1
2.761189,1.09,2.81,0.7681978,-0.99955946,2.23,1
2.6658804,1.02,3.36,0.8036201,0.98403203,2.28,1
2.720667,0.99,2.78,0.97055733,0.9781505,2.48,1
2.6812658,0.98,3.05,0.73290765,1.0,2.09,1
1.4784714,0.62,1.97,0.418,1.0,2.02,0
1.7488811,0.7,2.05,0.418,0.99999624,2.02,0

Test data:

Spectral_Index,W1-W2,W2-W3,HR0.3-100,HR50-2,Gamma
1.6724254,0.95,2.58,0.92031854,1.0,2.15
2.552926,0.93,2.74,0.63588345,-0.30092865,2.18
2.5737462,0.86,2.22,0.43023747,1.0,2.08
2.1701677,0.62,2.19,0.6892167,1.0,2.15
3.6152358,0.96,2.58,0.67760235,0.99704355,2.06
3.6193092,0.82,2.34,0.4083981,0.9973078,2.04
2.0209844,1.02,2.86,0.8595182,-0.9979041,2.36
2.166221,1.07,3.0,0.7177616,-0.99961376,2.3
2.7933478,0.94,2.4,0.678935,1.0,2.12
2.2969048,0.86,2.29,0.18689133,1.0,1.96
3.1255674,1.15,2.77,0.9290483,0.6387009,2.28
2.3548958,1.01,2.46,0.75331503,-1.0,2.21
3.9791226,1.15,3.04,0.87006325,-0.99919724,2.43
2.3430493,0.85,2.42,0.81132597,-0.9999996,2.04
3.7431624,0.79,2.57,0.704,0.99952716,2.20784
3.1846259,1.14,2.85,0.9104803,0.99891067,2.3
3.1416001,0.73,2.26,0.5679769,1.0,1.98
2.670179,0.85,2.66,0.7376513,0.97939825,2.1
3.010911,0.79,2.38,0.21750104,0.21187924,1.82
1.4430648,0.9,2.38,0.7361963,0.999758,2.11
2.8149416,1.07,2.62,0.94750744,0.9967568,2.4
3.8395922,1.09,2.91,0.27485812,0.99887043,2.05
3.1686394,0.66,2.11,0.529385,1.0,1.9
3.190167,1.09,3.1,0.8501991,0.9507157,2.23
3.8597586,1.13,3.64,0.89043206,0.17880388,2.42
2.1516426,0.85,2.24,0.6673518,0.9985168,2.2
2.1318088,0.98,2.64,0.85542095,1.0,2.22
1.6740437,0.97,2.99,0.86632746,0.9983954,2.41
4.273427,1.01,2.71,0.8941501,0.64256436,2.47
2.284782,0.92,2.7,0.5820462,0.6981752,2.1
3.343603,1.06,2.84,0.6901738,0.83269715,2.13
5.766362,1.2,3.74,0.99009913,0.99998844,2.49
2.1547525,0.95,3.02,0.75229234,0.99604213,2.57
2.9853358,0.91,2.37,0.62881154,-0.98792726,2.06
2.8614197,0.82,2.15,0.75643075,1.0,2.19
3.6815813,1.14,3.24,0.8886577,-0.030438267,2.39
4.539201,1.17,2.83,0.93989134,0.23378997,2.55
3.35261,1.1,2.73,0.9184936,0.9998006,2.41
3.6697345,1.16,3.57,0.9515105,0.9999988,2.43
1.9781204,0.91,2.85,-0.06649571,0.9999991,1.7
2.6618617,1.1,3.24,0.8348949,-0.9834342,2.29
3.8140056,1.18,3.25,0.8766021,1.0,2.39
2.1926181,1.05,2.3,0.6880097,1.0,2.3
2.0248337,0.83,2.29,0.3604591,0.46159065,2.05
3.904931,1.13,2.46,0.9100119,1.0,2.32
1.9945884,0.94,2.5,0.4632657,0.9869119,2.05
3.3342967,1.1,3.04,0.51323855,-0.5262294,2.23
2.3138714,0.91,2.36,0.90414697,0.9999977,2.29
2.3118904,1.04,3.01,0.87289846,0.998577,2.29
2.246307,1.07,2.72,0.6147379,0.9999993,2.11
1.6369493,0.89,2.34,0.61421084,0.9997295,2.22
3.6198807,0.93,2.62,0.7463702,0.9994778,2.07

1 Answer:

Answer 0 (score: 1)

There is no problem; the tree is behaving exactly as expected.

A decision tree computes class probabilities from the number of training samples of each class that fall into a given leaf.
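To make the leaf-counting rule concrete, here is a minimal sketch. It assumes synthetic data from make_classification and an arbitrary max_depth=2 (neither comes from the question), chosen only so that the leaves stay mixed. It compares predict_proba for class 1 with the fraction of class-1 training samples in the leaf each sample falls into.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data (an assumption, not the dataset from the question).
X, y = make_classification(n_samples=200, n_features=6, random_state=0)

# A shallow tree, so that leaves contain a mix of both classes.
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)

leaf_ids = clf.apply(X)               # index of the leaf each training sample ends up in
proba = clf.predict_proba(X)[:, 1]    # predicted probability of class 1

for leaf in np.unique(leaf_ids):
    in_leaf = leaf_ids == leaf
    frac_class1 = y[in_leaf].mean()   # fraction of class-1 training samples in this leaf
    print(leaf, in_leaf.sum(), frac_class1, proba[in_leaf][0])
```

Within each leaf, the last two printed values should agree: every sample in a leaf receives the same probability, and that probability is the leaf's class fraction.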

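The documentation says:

"The default values for the parameters controlling the size of the trees (e.g. max_depth, min_samples_leaf, etc.) lead to fully grown and unpruned trees"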

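That is, the tree is grown until it fits (overfits) the training data perfectly. This means all training samples in each leaf belong to the same class, so for any test sample the predicted probability of a class is either 1 (the leaf it lands in contains only that class) or 0 (it does not).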
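The effect on the ROC curve can be sketched as follows. This is an illustration under assumptions, not something the answer prescribes: the data is synthetic (make_classification with some label noise), and min_samples_leaf=10 is an arbitrary value picked only to stop the tree before its leaves become pure. A fully grown tree yields scores of 0 or 1 only, so roc_curve has at most three points to return; a size-limited tree can yield fractional scores and therefore more thresholds.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.metrics import roc_curve
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data with some label noise (an assumption for illustration).
X, y = make_classification(n_samples=300, n_features=6, flip_y=0.1, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0
)

# Default parameters: the tree grows until every leaf is pure.
full = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
p_full = full.predict_proba(X_test)[:, 1]
print(np.unique(p_full))        # contains only 0.0 and/or 1.0

# Size-limited tree: leaves keep at least 10 training samples and may stay mixed.
limited = DecisionTreeClassifier(min_samples_leaf=10, random_state=0)
limited.fit(X_train, y_train)
p_limited = limited.predict_proba(X_test)[:, 1]
print(np.unique(p_limited))     # typically includes fractional values

# Number of ROC points available from each set of scores.
fpr_full, tpr_full, _ = roc_curve(y_test, p_full)
fpr_lim, tpr_lim, _ = roc_curve(y_test, p_limited)
print(len(fpr_full), len(fpr_lim))
```

Whether limiting the tree is appropriate depends on the data; other size controls such as max_depth behave similarly, and the value used here is not tuned.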