我有一个libsvm格式的列(spark的ml库)field1:value field2:value ...
+--------------+-----+
| features|label|
+--------------+-----+
| a:1 b:2 c:3| 0|
| a:4 b:5 c:6| 0|
| a:7 b:8 c:9| 1|
|a:10 b:11 c:12| 0|
+--------------+-----+
我想提取值并将它们保存在pyspark中每一行的数组中
features.printSchema()
root
|-- features: string (nullable = false)
|-- label: integer (nullable = true)
我正在使用以下udf,因为受影响的列是数据框的一部分
from pyspark.sql.functions import udf
from pyspark.ml.linalg import Vectors
features_expl = udf(lambda features: Vectors.dense(features.split(" ")).map(lambda feat: float(str(feat.split(":")[1]))))
features=features.withColumn("feats", features_expl(features.features))
我得到的结果是: ValueError:无法将字符串转换为float:mobile:0.0 似乎它不执行第二次拆分,而是对字符串调用float()。
我想得到的是:
+--------------+-----+
| features|label|
+--------------+-----+
| [1, 2, 3]| 0|
| [4, 5, 6]| 0|
| [7, 8, 9]| 1|
| [10, 11, 12]| 0|
+--------------+-----+
答案 0 :(得分:0)
您的udf
有两个主要问题。首先,它无法按预期工作。将代码的核心视为以下功能:
from pyspark.ml.linalg import Vectors
def features_expl_non_udf(features):
return Vectors.dense(
features.split(" ")).map(lambda feat: float(str(feat.split(":")[1]))
)
如果您使用自己的一个字符串调用它:
features_expl_non_udf("a:1 b:2 c:3")
#ValueError: could not convert string to float: a:1
因为features.split(" ")
返回['a:1', 'b:2', 'c:3']
,您将其传递给Vectors.dense
构造函数。这没有任何意义。
您打算做的是首先在空间上分割,然后在:
上分割结果列表的每个值。然后,您可以将这些值转换为float
并将列表传递给Vectors.dense
。
这是您的逻辑的正确实现:
def features_expl_non_udf(features):
return Vectors.dense(map(lambda feat: float(feat.split(":")[1]), features.split()))
features_expl_non_udf("a:1 b:2 c:3")
#DenseVector([1.0, 2.0, 3.0])
现在udf
的第二个问题是您没有指定returnType
。对于DenseVector
,您需要use VectorUDT
as the returnType
。
from pyspark.sql.functions import udf
from pyspark.ml.linalg import VectorUDT
features_expl = udf(
lambda features: Vectors.dense(
map(lambda feat: float(feat.split(":")[1]), features.split())
),
VectorUDT()
)
features.withColumn("feats", features_expl(features.features)).show()
#+--------------+-----+----------------+
#| features|label| feats|
#+--------------+-----+----------------+
#| a:1 b:2 c:3| 0| [1.0,2.0,3.0]|
#| a:4 b:5 c:6| 0| [4.0,5.0,6.0]|
#| a:7 b:8 c:9| 1| [7.0,8.0,9.0]|
#|a:10 b:11 c:12| 0|[10.0,11.0,12.0]|
#+--------------+-----+----------------+
作为替代方案,您可以使用regexp_replace
和split
在火花端进行字符串处理,但是仍然需要使用udf
将最终输出转换为DenseVector
。
from pyspark.sql.functions import regexp_replace, split, udf
from pyspark.ml.linalg import Vectors, VectorUDT
toDenseVector = udf(Vectors.dense, VectorUDT())
features.withColumn(
"features",
toDenseVector(
split(regexp_replace("features", r"\w+:", ""), "\s+").cast("array<float>")
)
).show()
#+----------------+-----+
#| features|label|
#+----------------+-----+
#| [1.0,2.0,3.0]| 0|
#| [4.0,5.0,6.0]| 0|
#| [7.0,8.0,9.0]| 1|
#|[10.0,11.0,12.0]| 0|
#+----------------+-----+