一种热编码复合字段

时间:2018-06-30 21:37:42

标签: apache-spark pyspark one-hot-encoding multivalue

我想使用OneHotEncoder转换具有相同分类值的多列。我创建了一个复合字段,并尝试在其上使用OneHotEncoder,如下所示:(项目1-3来自相同的项目列表)

import pyspark.sql.functions as F

df = df.withColumn("basket", myConcat("item1", "item2", "item3")) 
indexer = StringIndexer(inputCol="basket", outputCol="basketIndex")
indexed = indexer.fit(df).transform(df)
encoder = OneHotEncoder(setInputCol="basketIndex", setOutputCol="basketVec")

encoded = encoder.transform(indexed)

def myConcat(*cols):
    return F.concat(*[F.coalesce(c, F.lit("*")) for c in cols])

我遇到了内存不足错误

这种方法行得通吗?如何使用同一列表中的分类值对组合字段或多列进行热编码?

2 个答案:

答案 0 :(得分:2)

如果您有分类值数组,为什么不尝试使用CountVectorizer:

import pyspark.sql.functions as F
from pyspark.ml.feature import CountVectorizer

df = df.withColumn("basket", myConcat("item1", "item2", "item3")) 
indexer = CountVectorizer(inputCol="basket", outputCol="basketIndex")
indexed = indexer.fit(df).transform(df)

答案 1 :(得分:0)

注意:由于我是新用户,我还不能发表评论。

“ item1”,“ item2”和“ item3”的基数是什么

更具体地说,以下印刷品给出的值是什么?

k1 = df.item1.nunique()
k2 = df.item2.nunique()
k3 = df.item3.nunique()
k = k1 * k2 * k3
print (k1, k2, k3)

一种热编码基本上是创建一个非常稀疏的矩阵,该矩阵的行数与原始数据帧的行数相同,其中k列为附加列,其中k =上面打印的三个数字的乘积。

因此,如果您的3个数字很大,则会出现内存不足错误。

唯一的解决方案是:

(1)增加您的记忆力或 (2)在类别之间引入层次结构,并使用更高级别的类别来限制k。