带有MLLIB的pyspark数据帧中的点积

时间:2019-05-02 10:57:35

标签: python apache-spark pyspark apache-spark-mllib

我在pyspark中有一个非常简单的数据框,如下所示:

from pyspark.sql import Row
from pyspark.mllib.linalg import DenseVector

row = Row("a", "b")
df = spark.sparkContext.parallelize([
    offer_row(DenseVector([1, 1, 1]), DenseVector([1, 0, 0])),
]).toDF()

我想在不求助于UDF调用的情况下计算这些向量的点积。

spark MLLIB documentation引用了dot上的DenseVectors方法,但是如果我尝试按以下方法应用此方法:

df_offers = df_offers.withColumn("c", col("a").dot(col("b")))

我收到如下错误:

TypeError: 'Column' object is not callable

有人知道这些mllib方法是否可以在DataFrame对象上调用吗?

2 个答案:

答案 0 :(得分:0)

没有。您必须使用udf:

from pyspark.sql.functions import udf

@udf("double")
def dot(x, y):
    if x is not None and y is not None:
        return float(x.dot(y))

答案 1 :(得分:0)

在这里,您将dot方法应用于列而不是DenseVector上,这实际上是行不通的:

df_offers = df_offers.withColumn("c", col("a").dot(col("b")))

您将必须使用udf:

from pyspark.sql.functions import udf, array
from pyspark.sql.types import DoubleType

def dot_fun(array):
    return array[0].dot(array[1])

dot_udf = udf(dot_fun, DoubleType())

df_offers = df_offers.withColumn("c", dot_udf(array('a', 'b')))