随机森林的函数predict_proba()如何工作?

时间:2017-06-13 13:49:58

标签: python scikit-learn random-forest predict

我正在python 3.4工作,我尝试使用sklearn random forest模型做一些预测。 我首先生成confusion matrix并计算精度,如下所述:

accuracy= confusion_matrix.trace() / len(testing_dataframe)

这给了我一个大约0.18的准确度,这还不够好

因此,为了获得更高的结果,我使用predict_proba()方法生成了6个命题(我有大约1000个标签),并按如下方式计算准确度:

Boolean_array=np.zeros(len(testing_dataframe))
if label[i] is in propositions[i]:
   Boolean_array[i]= True
else:
   Boolean_array[i]= False

accuracy=numpy.mean(propositions)

这给了我一个大约0.16的准确度,这很奇怪,因为我期待它更高。

这就是为什么我想知道predict()predict_proba()是否相关,我的意思是:

predict(X)=predict_proba(X).max()

是否有人知道predict_proba如何运作,为什么我的准确度会降低?

修改

以下是我的数据的一小部分示例:

功能DataFrame

    f0              f1          f2        f3       f4             f5              f6       f7        f8             f9       ...        f90        f91       f92         f93         f94         f95         f96          f97        f98        f99
0   0.535271    0.025914    0.562412    -0.419993   0.681204    0.454155    -0.011025   0.038262    0.527357    -0.050546   ...     0.030861    -0.610272   0.317212    -0.418725   0.177978    -0.417070   0.053116    0.395432    0.291500    1.316563
1   -0.901409   -0.516550   -0.332680   1.256825    0.493730    -0.874850   -0.709655   1.452980    -1.383530   0.168827    ...     0.035605    0.445295    -1.631235   0.297015    0.204900    -0.064275   -0.739375   0.583020    -0.949815   1.621750
2   0.160291    0.817350    -0.751626   0.226337    0.727847    0.086183    -0.376025   1.272175    0.031435    0.418815    ...     0.025983    -0.118028   -0.608380   -0.626938   0.212298    -0.948225   0.461030    0.565805    0.860127    1.669100
3   -0.476125   -0.670245   0.146448    0.242735    1.179170    -1.206770   -0.374161   0.474205    0.605703    -0.915345   ...     -0.737470   1.207160    1.749050    1.401800    -0.694545   0.621000    0.133150    0.117960    -0.712190   -0.829515
4   1.592632    0.406484    -0.582462   -0.273904   -0.043725   0.473370    -0.561084   -0.586366   0.395254    -0.137352   ...     -1.805612   -1.325500   -0.304604   0.031349    1.496978    -0.026860   0.023190    1.226166    1.279988    1.158928

标签DataFrame

0    1360
1     268
2     269
3    1047
4    1344
Name: labels, dtype: int64

提前致谢

0 个答案:

没有答案