Spark-更改属于数据集中长尾记录的记录的值

时间:2018-06-22 17:43:07

标签: scala apache-spark machine-learning

我正在尝试解决机器学习问题中的数据清理步骤,我应该将长尾中的所有元素归为一个名为“其他”的常见类别。例如,我有一个像这样的数据框:

val df = sc.parallelize(Seq(
(1, "ABC"),
(2, "ABC"),
(3, "123"),
(4, "FPK"),
(5, "FPK"),
(6, "ABC"),
(7, "ABC"),
(8, "980"),
(9, "abc"),
(10, "FPK")
)).toDF("n", "s")

我希望保留类别"ABC""FPK",因为它们出现了几次,但我不想为123,980,abc拥有一个不同的类别,因为它们只出现了一次。所以我想拥有的是:

+---+------+
|  n|     s|
+---+------+
|  1|   ABC|
|  2|   ABC|
|  3|Others|
|  4|   FPK|
|  5|   FPK|
|  6|   ABC|
|  7|   ABC|
|  8|Others|
|  9|Others|
| 10|   FPK|
+---+------+

要实现这一点,我尝试过的是:

val newDF = df.withColumn("s",when($"s".isin("123","980","abc"),"Others").otherwise('s)

这很好。

但是我想以编程的方式决定哪些类别属于长尾,在我的情况下,该类别在原始数据帧中仅出现一次。因此,我编写了此代码,以创建一个数据框,其类别仅出现一次:

val longTail = df.groupBy("s").agg(count("*").alias("cnt")).orderBy($"cnt".desc).filter($"cnt"<2)

+---+---+
|  s|cnt|
+---+---+
|980|  1|
|abc|  1|
|123|  1|
+---+---+

现在,我正在尝试将longTail数据集中的列“ s”的值转换为一个List,以与之前进行硬编码的那个交换。所以我尝试了:

 val ar = longTail.select("s").collect().map(_(0)).toList
  

ar:List [Any] = List(123,980,abc)

但是当我尝试添加ar

val newDF = df.withColumn("s",when($"s".isin(ar),"Others").otherwise('s))

我收到以下错误:

  

java.lang.RuntimeException:不支持的文字类型类   scala.collection.immutable。$ colon $冒号列表(123,980,abc)

我想念什么?

2 个答案:

答案 0 :(得分:3)

这是正确的语法:

scala> df.withColumn("s", when($"s".isin(ar : _*), "Others").otherwise('s)).show
+---+------+
|  n|     s|
+---+------+
|  1|   ABC|
|  2|   ABC|
|  3|Others|
|  4|   FPK|
|  5|   FPK|
|  6|   ABC|
|  7|   ABC|
|  8|Others|
|  9|Others|
| 10|   FPK|
+---+------+

这称为重复参数。 cf here

答案 1 :(得分:3)

您不必经历所有麻烦事,您可以使用window函数获取每个组的counts并检查使用when/otherwise函数填充Others或不填充以下内容

val df = sc.parallelize(Seq(
  (1, "ABC"),
  (2, "ABC"),
  (3, "123"),
  (4, "FPK"),
  (5, "FPK"),
  (6, "ABC"),
  (7, "ABC"),
  (8, "980"),
  (9, "abc"),
  (10, "FPK")
)).toDF("n", "s")

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions._
df.withColumn("s", when(count("s").over(Window.partitionBy("s").orderBy("n").rowsBetween(Long.MinValue, Long.MaxValue)) > 1, col("s")).otherwise("Others")).show(false)

应该给您

+---+------+
|n  |s     |
+---+------+
|4  |FPK   |
|5  |FPK   |
|10 |FPK   |
|8  |Others|
|9  |Others|
|1  |ABC   |
|2  |ABC   |
|6  |ABC   |
|7  |ABC   |
|3  |Others|
+---+------+

我希望答案会有所帮助