Creating a one-hot encoder: CountVectorizer fails on ArrayType(IntegerType,true)

Time: 2018-09-25 11:24:23

Tags: apache-spark pyspark apache-spark-sql apache-spark-mllib one-hot-encoding

I am trying to build a one-hot encoder for the following input data:

+------+--------------------+
|userid|     categoryIndexes|
+------+--------------------+
| 24868|              [7276]|
| 35335|             [12825]|
| 42634|      [14550, 14550]|
| 51183|              [7570]|
| 61065|             [14782]|
| 70292|              [7282]|
| 72326|      [14883, 14877]|
| 96632|             [14902]|
| 99703|             [14889]|
|121994|       [16000, 7417]|
|144782|      [12139, 12139]|
|175886|        [7305, 7305]|
|221451|      [14889, 12139]|
|226945|             [18097]|
|250401|              [7278]|
|256892|        [7383, 5514]|
|270043|              [7442]|
|272338|              [7306]|
|284802|      [18310, 14898]|
+------+--------------------+

Following "Aggregating a One-Hot Encoded feature in pyspark" and "Encode and assemble multiple features in PySpark", I tried to solve it with:

from pyspark.ml.feature import CountVectorizer
from pyspark.sql import functions as F

# Collect each user's level-3 category IDs into one list per user
df_user_catlist = df_order.groupBy("userid").agg(F.collect_list('level3_cat').alias('categoryIndexes'))
cv = CountVectorizer(inputCol='categoryIndexes', outputCol='categoryVec')
transformed_df = cv.fit(df_user_catlist).transform(df_user_catlist)
transformed_df.show()

but got the following error:

IllegalArgumentException: u'requirement failed: Column category must be of type equal to one of the following types: [ArrayType(StringType,true), ArrayType(StringType,false)] but was actually of type ArrayType(IntegerType,true).'

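I notice the problem is that the input data is IntegerType rather than StringType. Could you tell me (a) how to convert it to StringType, or (b) whether there is a better way to one-hot encode it?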

1 Answer:

Answer 0 (score: 2)

You need to cast the category indexes to strings before passing them to CountVectorizer:

from pyspark.sql import functions as F

# CountVectorizer only accepts array<string>, so cast the array<int> column
df_user_catlist = df_user_catlist \
    .withColumn('categoryIndexes', 
         F.col('categoryIndexes').cast('array<string>'))
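To make the effect of the cast concrete, here is a plain-Python sketch (not Spark code) that mimics what CountVectorizer does once the column holds strings: turn each ID into a string token, build a vocabulary ordered by term frequency across the whole corpus, and emit one sparse count vector per user. The function name `count_vectorize` is hypothetical, and two details are simplifications of this sketch rather than Spark behavior: ties in frequency are broken alphabetically for determinism, and values are ints instead of Spark's doubles.

```python
from collections import Counter


def count_vectorize(rows, binary=False):
    """Hypothetical sketch of CountVectorizer semantics.

    rows: list of lists of int category IDs (like the categoryIndexes column).
    Returns (vocabulary, vectors); each vector is a sparse dict {vocab_index: value}.
    """
    # Equivalent of cast('array<string>'): every ID becomes a string token
    docs = [[str(x) for x in row] for row in rows]

    # Vocabulary ordered by total term frequency across all rows, most frequent
    # first; ties broken alphabetically (an assumption of this sketch)
    corpus_counts = Counter(tok for doc in docs for tok in doc)
    vocabulary = sorted(corpus_counts, key=lambda t: (-corpus_counts[t], t))
    index = {tok: i for i, tok in enumerate(vocabulary)}

    vectors = []
    for doc in docs:
        # Count how often each vocabulary entry occurs in this row
        counts = Counter(index[tok] for tok in doc)
        if binary:
            # Presence/absence only: repeated IDs still map to 1
            counts = {i: 1 for i in counts}
        vectors.append(dict(counts))
    return vocabulary, vectors


rows = [[7276], [14550, 14550], [14889, 12139], [14889]]
vocabulary, vectors = count_vectorize(rows)
# vocabulary -> ['14550', '14889', '12139', '7276']
# vectors[1] -> {0: 2}   (user with [14550, 14550])
_, binary_vectors = count_vectorize(rows, binary=True)
# binary_vectors[1] -> {0: 1}
```

Note that a user such as 42634 with `[14550, 14550]` gets a value of 2, so the default output is a count vector, not a strict one-hot (multi-hot) encoding. If 0/1 indicators are wanted, PySpark's `CountVectorizer` accepts a `binary=True` parameter, which is what the `binary` flag above imitates.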