I have some feature columns which I pack into a Vector using Spark's VectorAssembler, as shown below. `data` is the input DataFrame (of type `spark.sql.DataFrame`).
import org.apache.spark.ml.feature.VectorAssembler

val featureCols = Array("feature_1", "feature_2", "feature_3")
val featureAssembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features")
val dataWithFeatures = featureAssembler.transform(data)
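For reference, a minimal way to check what comes out (the schema in the comments below is what VectorAssembler produces, assuming the three input columns are doubles):

// VectorAssembler adds a single "features" column of Spark ML vector type
dataWithFeatures.printSchema()
// root
//  |-- feature_1: double (nullable = true)
//  |-- feature_2: double (nullable = true)
//  |-- feature_3: double (nullable = true)
//  |-- features: vector (nullable = true)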
I am developing a custom classifier using the Classifier and ClassificationModel developer APIs. ClassificationModel requires implementing a predictRaw() function, which outputs a vector of raw predictions for each possible label from the model:
def predictRaw(features: FeaturesType): Vector
This function is defined by the API: it takes a parameter `features` of type FeaturesType and outputs a Vector (in my case I take this to be a DenseVector, since DenseVector extends the Vector trait).
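To make the setting concrete, this is roughly the shape I'm working in (a sketch only; the developer-API boilerplate such as uid, numClasses and copy is omitted, and MyModel is a placeholder name):

import org.apache.spark.ml.linalg.{DenseVector, Vector}

// Placeholder class standing in for my ClassificationModel subclass
class MyModel {
  def predictRaw(features: Vector): DenseVector = {
    // I need the raw features as an Array[Double] at this point
    ???
  }
}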
Because of the packing done by VectorAssembler, the `features` column is of type Vector, and each element is itself the vector of raw features for one training sample. For example:
features column - of type Vector
[1.0, 2.0, 3.0] - element 1, itself a vector
[3.5, 4.5, 5.5] - element 2, itself a vector
I need to extract these features into an Array[Double] to implement my predictRaw() logic. Ideally, to preserve the cardinality, I would like the following result:
`val result: Array[Double] = Array(1.0, 3.5, 2.0, 4.5, 3.0, 5.5)`
i.e. in column-major order, since I will turn this into a matrix.
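For concreteness, here is the kind of transformation I'm after, as a driver-side sketch (this collects everything to the driver, so it is only viable for small data):

import org.apache.spark.ml.linalg.Vector

// Collect each row's feature vector as an Array[Double] (row-major),
// then transpose and flatten to get the column-major layout.
val rows: Array[Array[Double]] = dataWithFeatures
  .select("features")
  .collect()
  .map(_.getAs[Vector](0).toArray)

val result: Array[Double] = rows.transpose.flatten
// For the example above: Array(1.0, 3.5, 2.0, 4.5, 3.0, 5.5)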
I have tried:
val array = features.toArray // this gives an array of vectors and doesn't work
Because of the way VectorAssembler packs the features, I also tried passing the features in as a DataFrame object rather than a Vector, but the API expects a Vector. For example, the following function works on its own, but does not conform to the API, which expects FeaturesType to be a Vector rather than a DataFrame:
def predictRaw(features: DataFrame): DenseVector = {
  // flatMap flattens each row's vector so collect() yields a single Array[Double]
  val featuresArray: Array[Double] =
    features.rdd.flatMap(r => r.getAs[Vector](0).toArray).collect()
  // rest of logic would go here
}
My problem is that `features` is of type Vector, not DataFrame. The other option might be to pack the features into a DataFrame, but I don't know how to do that without using VectorAssembler.
Any suggestions are appreciated, thanks! I have looked at Access element of a vector in a Spark DataFrame (Logistic Regression probability vector), but that is in Python and I'm using Scala.
Answer 0 (Score: 2)
If you just want to convert a DenseVector into an Array[Double], that is fairly simple with a UDF:
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.sql.functions.udf
import spark.implicits._ // for the 'features column syntax

// Cast the column value to a DenseVector and pull out its backing array
val toArr: Any => Array[Double] = _.asInstanceOf[DenseVector].toArray
val toArrUdf = udf(toArr)
val dataWithFeaturesArr = dataWithFeatures.withColumn("features_arr", toArrUdf('features))
This will give you a new column:
|-- features_arr: array (nullable = true)
| |-- element: double (containsNull = false)
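Once the column is an array type, individual elements can be read out with getItem, for example (a usage sketch, with spark.implicits._ in scope as above):

// Pull the first raw feature of every sample out of the array column
dataWithFeaturesArr.select($"features_arr".getItem(0).alias("first_feature")).show(3)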
Answer 1 (Score: 1)
Here is a way to get a DataFrame of (String, Array) from a DataFrame of (String, Vector) without a UDF. The main idea is to go through an intermediate RDD, cast to a Vector, and use its toArray method:
import org.apache.spark.ml.linalg.Vector
import spark.implicits._ // for rdd.toDF

// Map each row to (String, Array[Double]) and turn it back into a DataFrame
val arrayDF = vectorDF.rdd
  .map(x => x.getAs[String](0) -> x.getAs[Vector](1).toArray)
  .toDF("word", "array")
Answer 2 (Score: 0)
Spark 3.0 added a vector_to_array UDF, so there is no need to implement this yourself: https://github.com/apache/spark/pull/26910. Its internal implementation looks like this:
import org.apache.spark.ml.linalg.{SparseVector, Vector}
import org.apache.spark.mllib.linalg.{Vector => OldVector}
import org.apache.spark.sql.functions.udf

// Accepts both the new ml vectors and the old mllib vectors
private val vectorToArrayUdf = udf { vec: Any =>
  vec match {
    case v: Vector => v.toArray
    case v: OldVector => v.toArray
    case v => throw new IllegalArgumentException(
      "function vector_to_array requires a non-null input argument and input type must be " +
      "`org.apache.spark.ml.linalg.Vector` or `org.apache.spark.mllib.linalg.Vector`, " +
      s"but got ${ if (v == null) "null" else v.getClass.getName }.")
  }
}.asNonNullable()
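In practice you do not need to copy that private implementation; since Spark 3.0 the function is exposed publicly in org.apache.spark.ml.functions, for example (with spark.implicits._ in scope for the $ syntax):

import org.apache.spark.ml.functions.vector_to_array

// Converts an ml (or mllib) vector column into a native array<double> column
val dataWithArr = dataWithFeatures.withColumn("features_arr", vector_to_array($"features"))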
Answer 3 (Score: 0)
My case: the raw data after word2vec:
result.show(10,false)
+-------------+-----------------------------------------------------------------------------------------------------------+
|ip |features |
+-------------+-----------------------------------------------------------------------------------------------------------+
|1.1.125.120 |[0.0,0.0,0.0,0.0,0.0] |
|1.11.114.150 |[0.0,0.0,0.0,0.0,0.0] |
|1.116.114.36 |[0.022845590487122536,-0.012075710110366344,-0.034423209726810455,-0.04642726108431816,0.09164007753133774]|
|1.117.21.102 |[0.0,0.0,0.0,0.0,0.0] |
|1.119.13.5 |[0.0,0.0,0.0,0.0,0.0] |
|1.119.130.2 |[0.0,0.0,0.0,0.0,0.0] |
|1.119.132.162|[0.0,0.0,0.0,0.0,0.0] |
|1.119.133.166|[0.0,0.0,0.0,0.0,0.0] |
|1.119.136.170|[0.0,0.0,0.0,0.0,0.0] |
|1.119.137.154|[0.0,0.0,0.0,0.0,0.0] |
+-------------+-----------------------------------------------------------------------------------------------------------+
I want to drop the IPs whose embeddings are all zeros:
import org.apache.spark.sql.functions.udf
import org.apache.spark.ml.linalg.Vector
import spark.implicits._ // for the $ column syntax

// Convert the vector column into an array column so it can be compared to a literal
val vecToSeq = udf((v: Vector) => v.toArray).asNondeterministic
val output = result.select($"ip", vecToSeq($"features").alias("features"))

// =!= is the non-deprecated inequality operator; compare against an all-zero array of doubles
val select_output = output.filter(output("features") =!= Array(0.0, 0.0, 0.0, 0.0, 0.0))
select_output.show(5)
+-------------+--------------------+
| ip| features|
+-------------+--------------------+
| 1.116.114.36|[0.02284559048712...|
| 1.119.137.98|[-0.0244039318391...|
|1.119.177.102|[-0.0801128149032...|
|1.119.186.170|[0.01125990878790...|
|1.119.193.226|[0.04201301932334...|
+-------------+--------------------+
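As an aside, here is a sketch of an alternative that skips the array conversion entirely and filters on the vector itself (numNonzeros is a method on org.apache.spark.ml.linalg.Vector):

import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.functions.udf

// A vector is all-zero exactly when it has no nonzero entries
val isNonZero = udf((v: Vector) => v.numNonzeros > 0)
val nonZeroRows = result.filter(isNonZero($"features"))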