Spark 1.3.1中的averange randomForest预测

时间:2015-12-14 13:15:19

标签: python apache-spark

我试图计算Spark 1.3.1中的randomForest预测的平均值,因为所有树的预测概率仅在未来版本中可用。

直到现在我能做的最好的事情是使用以下功能:

def calculaProbs(dados, modelRF):
    trees = modelRF._java_model.trees()
    nTrees = modelRF.numTrees()
    nPontos = dados.count()
    predictions = np.zeros(nPontos)
    for i in range(nTrees):
        dtm = DecisionTreeModel(trees[i])
        predictions += np.array(dtm.predict(dados.map(lambda x: x.features)).collect())
    predictions = predictions/nTrees
    return predictions

此代码运行速度太慢,正如预期的那样,因为我正在从每个树收集预测并将其添加到Driver中。 我不能将$ dtm.predit()$放在这个版本的Spark中的Map操作中。以下是文档中的注释:"注意:在Python中,目前无法在RDD转换或操作中使用预测。直接在RDD上调用预测。"

任何改善表现的想法?如何在不将其值收集到矢量的情况下添加2个RDD的值?

0 个答案:

没有答案