One-hot encoding in an RDD in Scala

Asked: 2016-12-10 05:54:36

Tags: scala apache-spark

I have user data from the MovieLens ml-100k dataset.

Sample rows are -

1|24|M|technician|85711
2|53|F|other|94043
3|23|M|writer|32067
4|24|M|technician|43537
5|33|F|other|15213

I have read the data into an RDD as follows -

scala> val user_data =  sc.textFile("/home/user/Documents/movielense/ml-100k/u.user").map(x=>x.split('|'))
user_data: org.apache.spark.rdd.RDD[Array[String]] = MapPartitionsRDD[5] at map at <console>:29

scala> user_data.take(5)
res0: Array[Array[String]] = Array(Array(1, 24, M, technician, 85711), Array(2, 53, F, other, 94043), Array(3, 23, M, writer, 32067), Array(4, 24, M, technician,    43537), Array(5, 33, F, other, 15213))


# encode the distinct professions with zipWithIndex -
scala> val indexed_profession = user_data.map(x=>x(3)).distinct().sortBy[String](x=>x).zipWithIndex()
indexed_profession: org.apache.spark.rdd.RDD[(String, Long)] = ZippedWithIndexRDD[18] at zipWithIndex at <console>:31

scala> indexed_profession.collect()
res1: Array[(String, Long)] = Array((administrator,0), (artist,1), (doctor,2), (educator,3), (engineer,4), (entertainment,5), (executive,6), (healthcare,7),  (homemaker,8), (lawyer,9), (librarian,10), (marketing,11), (none,12), (other,13), (programmer,14), (retired,15), (salesman,16), (scientist,17), (student,18), (technician,19), (writer,20))
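For intuition, the same distinct / sort / zipWithIndex chain can be sketched with plain Scala collections (no Spark needed; the sample occupations below are hardcoded for illustration):

```scala
// A few occupations as they might come from column 3 of u.user.
val occupations = Seq("technician", "other", "writer", "technician", "other")

// distinct -> sorted -> zipWithIndex mirrors the RDD chain above,
// producing a stable occupation -> index mapping.
val indexed: Map[String, Int] = occupations.distinct.sorted.zipWithIndex.toMap

// e.g. indexed("other") == 0, indexed("technician") == 1, indexed("writer") == 2
```

Sorting before zipWithIndex is what makes the index assignment deterministic across runs.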

I want to do one-hot encoding for the Occupation column.

Expected output is -

 userId   Age  Gender  Occupation   Zipcodes technician  other  writer 
 1        24    M      technician   85711      1           0     0
 2        53    F      other        94043      0           1     0
 3        23    M      writer       32067      0           0     1
 4        24    M      technician   43537      1           0     0
 5        33    F      other        15213      0           1     0

How do I achieve this on an RDD in Scala? I want to perform the operation on the RDD without converting it to a DataFrame.

Any help is appreciated.

Thanks.

2 answers:

Answer 0 (score: 0)

I did it in the following way -

1) Read the user data -

scala> val user_data =  sc.textFile("/home/user/Documents/movielense/ml-100k/u.user").map(x=>x.split('|'))

2) Show 5 rows of data -

scala> user_data.take(5)
res0: Array[Array[String]] = Array(Array(1, 24, M, technician, 85711), Array(2, 53, F, other, 94043), Array(3, 23, M, writer, 32067), Array(4, 24, M, technician,    43537), Array(5, 33, F, other, 15213))

3) Create a map from profession to index -

scala> val indexed_profession = user_data.map(x=>x(3)).distinct().sortBy[String](x=>x).zipWithIndex().collectAsMap()

scala> indexed_profession
res35: scala.collection.Map[String,Long] = Map(scientist -> 17, writer -> 20, doctor -> 2, healthcare -> 7, administrator -> 0, educator -> 3, homemaker -> 8, none -> 12, artist -> 1, salesman -> 16, executive -> 6, programmer -> 14, engineer -> 4, librarian -> 10, technician -> 19, retired -> 15, entertainment -> 5, marketing -> 11, student -> 18, lawyer -> 9, other -> 13)

4) Create an encode function that one-hot encodes a profession -

scala> def encode(x: String) = {
     |   val encodeArray = Array.fill(21)(0)
     |   encodeArray(indexed_profession(x).toInt) = 1
     |   encodeArray
     | }
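As a side note, the hard-coded length 21 and the unchecked map lookup make this function fragile. A safer variant, sketched here in plain Scala against a hypothetical `professionIndex` map (standing in for the `indexed_profession` map of step 3), could look like:

```scala
// Hypothetical index map; in the answer this is the collectAsMap result.
val professionIndex: Map[String, Long] =
  Map("other" -> 0L, "technician" -> 1L, "writer" -> 2L)

// Vector length is taken from the map instead of a hard-coded constant,
// and an unknown profession yields an all-zero vector instead of throwing.
def encode(x: String): Array[Int] = {
  val arr = Array.fill(professionIndex.size)(0)
  professionIndex.get(x).foreach(i => arr(i.toInt) = 1)
  arr
}
```

The same shape works unchanged inside an RDD `map`, since the map is a small local collection that Spark can ship in the closure.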

5) Apply the encode function to the user data -

scala> val encode_user_data = user_data.map{ x => (x(0),x(1),x(2),x(3),x(4),encode(x(3)))}

6) Show the encoded data -

scala> encode_user_data.take(6)
res71: Array[(String, String, String, String, String, Array[Int])] = Array(
(1,24,M,technician,85711,Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0)),
(2,53,F,other,94043,Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0)),
(3,23,M,writer,32067,Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1)),
(4,24,M,technician,43537,Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0)),
(5,33,F,other,15213,Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0)),
(6,42,M,executive,98101,Array(0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)))

Answer 1 (score: 0)

[My solution is for DataFrames.] The following should help convert categorical columns into one-hot columns. You have to create a Map object, catMap, whose keys are column names and whose values are lists of categories.

    import org.apache.spark.sql.functions.{lower, when}

    // For every category of every categorical column, append a 0/1 column
    // named after the category.
    var outputDf = df
    for (cat <- catMap.keys) {
      val categories = catMap(cat)
      for (oneHotVal <- categories) {
        outputDf = outputDf.withColumn(oneHotVal,
          when(lower(outputDf(cat)) === oneHotVal, 1).otherwise(0))
      }
    }
    outputDf
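The core idea of that loop can be sketched over plain Scala tuples, with a hypothetical `catMap` and hardcoded rows standing in for the DataFrame (the `toLowerCase` comparison plays the role of the `when(lower(...) === ...)` expression above):

```scala
// Hypothetical category map: column name -> list of categories.
val catMap = Map("occupation" -> Seq("other", "technician", "writer"))

// Each row is (userId, occupation); case differences are normalized away,
// as the lower() call does in the DataFrame version.
val rows = Seq(("1", "Technician"), ("2", "other"))

// Append one 0/1 flag per category to every row.
val encoded = rows.map { case (id, occ) =>
  val flags = catMap("occupation").map(c => if (occ.toLowerCase == c) 1 else 0)
  (id, occ, flags)
}
```

Because the flags are derived per row from a fixed category list, every row gets the same column layout, which is exactly what the `withColumn` loop guarantees on the DataFrame side.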