是否可以在Spark Dataframe Column中存储numpy数组?

时间:2017-07-07 08:11:19

标签: numpy pyspark spark-dataframe

我有一个dataframe并且我将一个函数应用于它。此函数返回numpy array代码,如下所示:

create_vector_udf = udf(create_vector, ArrayType(FloatType()))
dataframe = dataframe.withColumn('vector', create_vector_udf('text'))
dmoz_spark_df.select('lang','url','vector').show(20)

现在火花似乎对此并不满意,并且不接受ArrayType(FloatType()) 我收到以下错误消息: net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.core.multiarray._reconstruct)

我可以numpyarray.tolist()并返回它的列表版本,但显然如果我想将其与array一起使用,我将永远重新创建numpy

那么有没有办法在numpy array dataframe中存储column

3 个答案:

答案 0 :(得分:0)

我还没有尝试过,但是也许您可以使用类似于spark_sklearn中的tqdm的UDT。

答案 1 :(得分:0)

问题的根源是从UDF返回的对象不符合声明的类型。 create_vector不仅必须返回numpy.ndarray,而且还必须将数字转换为与DataFrame API不兼容的相应NumPy类型。

唯一的选择是使用类似这样的东西:

udf(lambda x: create_vector(x).tolist(), ArrayType(FloatType()))

答案 2 :(得分:0)

一种方法是将DataFrame中numpy数组的每一行转换为整数列表。

df.col_2 = df.col_2.map(lambda x: [int(e) for e in x])

然后,将其直接转换为Spark DataFrame

df_spark = spark.createDataFrame(df)
df_spark.select('col_1', explode(col('col_2')).alias('col_2')).show(14)