Spark ML多类分类问题的评估指标

时间:2018-12-26 17:14:29

标签: apache-spark-ml

我正在寻找一个使用Spark-Scala的多类分类示例,但是我还找不到一个。具体来说,我想训练一个分类模型,并查看所有有关训练和测试数据的指标。

Spark ML(基于DataFrame的API)是否支持多类问题的混淆矩阵?

我正在寻找Spark v 2.2及更高版本的示例。一个端到端的示例将非常有用。我在这里找不到混淆矩阵评估-

https://spark.apache.org/docs/2.3.0/ml-classification-regression.html

2 个答案:

答案 0 :(得分:0)

应该是这样:

val metrics = new MulticlassMetrics(predictionAndLabels)
println(metrics.confusionMatrix)

分类指标在这里: https://spark.apache.org/docs/2.3.0/mllib-evaluation-metrics.html

答案 1 :(得分:0)

假设model是您训练的模型,而test是测试集, 这是用于计算python中的混淆矩阵的代码段:

import pandas as pd
from pyspark.mllib.evaluation import MulticlassMetrics
predictionAndLabels = model.transform(test).select('label', 'prediction')
metrics = MulticlassMetrics(predictionAndLabels.rdd.map(lambda x: tuple(map(float, x))))

confusion_matrix = metrics.confusionMatrix().toArray()
labels = [int(l) for l in metrics.call('labels')]
confusion_matrix = pd.DataFrame(confusion_matrix , index=labels, columns=labels)

请注意,由于某些原因metrics.labels未在pyspark中实现,因此我们直接调用scala后端