I am trying to create a one-hot encoder for the following input data:
+------+--------------------+
|userid|     categoryIndexes|
+------+--------------------+
| 24868|              [7276]|
| 35335|             [12825]|
| 42634|      [14550, 14550]|
| 51183|              [7570]|
| 61065|             [14782]|
| 70292|              [7282]|
| 72326|      [14883, 14877]|
| 96632|             [14902]|
| 99703|             [14889]|
|121994|       [16000, 7417]|
|144782|      [12139, 12139]|
|175886|        [7305, 7305]|
|221451|      [14889, 12139]|
|226945|             [18097]|
|250401|              [7278]|
|256892|        [7383, 5514]|
|270043|              [7442]|
|272338|              [7306]|
|284802|      [18310, 14898]|
+------+--------------------+
Following "Aggregating a One-Hot Encoded feature in pyspark" and "Encode and assemble multiple features in PySpark", I tried to solve it with:
from pyspark.sql import functions as F
from pyspark.ml.feature import CountVectorizer
df_user_catlist = df_order.groupBy("userid").agg(F.collect_list('level3_cat').alias('categoryIndexes'))
cv = CountVectorizer(inputCol='categoryIndexes', outputCol='categoryVec')
transformed_df = cv.fit(df_user_catlist).transform(df_user_catlist)
transformed_df.show()
But I got the following error:
IllegalArgumentException: u'requirement failed: Column category must be of type equal to one of the following types: [ArrayType(StringType,true), ArrayType(StringType,false)] but was actually of type ArrayType(IntegerType,true).'
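I noticed that the difference is that my input data is IntegerType rather than StringType. May I know (a) how to convert it to StringType, or (b) whether there is a better way to convert it to a one-hot encoding (OHE)?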
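Answer 0 (score: 2)
You need to cast the category indexes from integers to strings, because CountVectorizer only accepts an array of strings as input: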
from pyspark.sql import functions as F
df_user_catlist = df_user_catlist \
    .withColumn('categoryIndexes',
                F.col('categoryIndexes').cast('array<string>'))
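One thing to keep in mind after the cast: CountVectorizer produces term counts, so a row like [12139, 12139] gets a 2 in that slot rather than a 1. If you want a strict 0/1 encoding, CountVectorizer has a binary parameter that you can set to True. To make the intended result concrete, here is a minimal plain-Python sketch of that multi-hot encoding on a few rows from the question. It is an illustration only, not Spark's implementation: the helper names build_vocabulary and multi_hot are made up for this example, and the tie-breaking rule in the vocabulary order is an assumption added to keep the output deterministic.

```python
from collections import Counter

def build_vocabulary(rows):
    """Order distinct terms by total frequency, most frequent first.

    Ties are broken by the term itself; that rule is an assumption of
    this sketch, chosen only to make the ordering deterministic.
    """
    counts = Counter(term for row in rows for term in row)
    return sorted(counts, key=lambda term: (-counts[term], term))

def multi_hot(row, vocabulary):
    """Return a 0/1 vector marking which vocabulary terms occur in the row."""
    present = set(row)
    return [1 if term in present else 0 for term in vocabulary]

# A few rows from the question, with the integer IDs cast to strings first
# (the plain-Python counterpart of cast('array<string>')).
category_indexes = [[7276], [14889, 12139], [12139, 12139], [14889]]
as_strings = [[str(i) for i in row] for row in category_indexes]

vocabulary = build_vocabulary(as_strings)
vectors = [multi_hot(row, vocabulary) for row in as_strings]

for row, vec in zip(as_strings, vectors):
    print(row, vec)
```

Each position in a vector corresponds to one entry of vocabulary, so repeated categories within a row collapse to a single 1, which is the behaviour the binary option is meant to give.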