我正在使用pyspark随机森林分类器,并希望一旦返回预测就根据预测创建熊猫数据框。当我尝试这样做时,会发生最奇怪的异常。这是我的代码:
random_forest = RandomForestClassifier(labelCol = 'label', featuresCol = 'features', maxDepth = 4, impurity = 'entropy', numTrees = 10, maxBins = 250)
rf_model = random_forest.fit(training_data)
predictions = rf_model.transform(test_data)
# Where exception happens
df = predictions.select('rawPrediction', 'label', 'prediction').where((predictions.label == '1.0') & (predictions.prediction == '0.0')).toPandas()
# The code that works fine
label_pred_train = predictions.select('label', 'prediction')
print label_pred_train.rdd.zipWithIndex().countByKey()
当我尝试过滤预测并选择它们的一个子集以转换为熊猫数据帧时,就会发生问题。当我将toPandas
替换为count
,collect
等时,会发生同样的异常。最让我惊讶的是,当我删除它们并执行以下使用rdd的行时,计算一切正常,并返回结果。我已经阅读了几篇有关StringIndexer
的问题以及如何使用handleInvalid = 'keep'
的帖子,但不幸的是,我正在使用spark 2.1来运行它,老实说,它与{{ 1}},因为我能够拟合,变换并从模型中获得预测。我在这里可能会缺少什么吗?
这是完整的例外:
StringIndexer
答案 0 :(得分:0)
使用此
df = predictions.select('rawPrediction', 'label', 'prediction').where((predictions.label == 1.0 & (predictions.prediction == 0.0)).toPandas()
代替
df = predictions.select('rawPrediction', 'label', 'prediction').where((predictions.label == '1.0') & (predictions.prediction == '0.0')).toPandas()
不是使用'1.0'而是使用1.0,最有可能是因为它查找的是整数类型的字符串类型,而且我找不到它为它分配了null并使错误看不见的标签的原因。