Scala中的Spark:如何将两列与它们不同的位置数进行比较?

时间:2017-01-14 05:53:28

标签: scala apache-spark dataframe

我在Spark中有一个名为df的DataFrame。我在一些功能上训练了机器学习模型,只想计算labelprediction列之间的准确度。

scala> df.columns
res32: Array[String] = Array(feature1, feature2, label, prediction)

这在numpy中会非常简单:

accuracy = np.sum(df.label == df.prediction) / float(len(df))

使用Scala在Spark中有同样简单的方法吗?

我还应该提到我对Scala来说是全新的。

1 个答案:

答案 0 :(得分:1)

必需的导入:

import org.apache.spark.sql.functions.avg
import spark.implicits._

示例数据:

val df = Seq((0, 0), (1,  0), (1, 1), (1, 1)).toDF("label", "prediction")

解决方案:

df.select(avg(($"label" === $"prediction").cast("integer")))

结果:

+--------------------------------------+
|avg(CAST((label = prediction) AS INT))|
+--------------------------------------+
|                                  0.75|
+--------------------------------------+

添加:

.as[Double].first

.first.getDouble(0)

如果您需要本地值。如果你想计算替换:

avg(($"label" === $"prediction").cast("integer"))

sum(($"label" === $"prediction").cast("integer"))

count(when($"label" === $"prediction", true))