Filtering a Vector-type "features" column

Asked: 2017-07-06 18:54:51

Tags: apache-spark apache-spark-sql

I am working on a program where I need to show specific rows of a dataset based on certain conditions. The conditions apply to a features column that I created for a machine learning model. This features column is a Vector column, and when I try to filter it by passing a Vector value, I get the following error:

Exception in thread "main" java.lang.RuntimeException: Unsupported literal type class org.apache.spark.ml.linalg.DenseVector
    at org.apache.spark.sql.catalyst.expressions.Literal$.apply(literals.scala:75)
    at org.apache.spark.sql.functions$.lit(functions.scala:101)

This is the filtering code that throws the error:

dataset.where(dataset.col("features").notEqual(datapoint)); //datapoint is a Vector

Is there any way around this?

2 answers:

Answer 0 (score: 1)

You need to create a UDF to filter on Vectors. The following worked for me:

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.{Vector, Vectors}  // Vector is needed for the UDF below
import org.apache.spark.sql.functions.udf

val df = sc.parallelize(Seq(
  (1, 1, 1), (1, 2, 3), (1, 3, 5), (2, 4, 6),
  (2, 5, 2), (2, 6, 1), (3, 7, 5), (3, 8, 16),
  (1, 1, 1))).toDF("c1", "c2", "c3")

val dfVec =  new VectorAssembler()
  .setInputCols(Array("c1", "c2", "c3"))
  .setOutputCol("features")
  .transform(df)

// UDF factory: captures the reference vector and compares each row's vector against it
def vectors_unequal(vec1: Vector) = udf((vec2: Vector) => !vec1.equals(vec2))

val vecToRemove = Vectors.dense(1,1,1)

val filtered = dfVec.where(vectors_unequal(vecToRemove)(dfVec.col("features")))  
val filtered2 = dfVec.filter(vectors_unequal(vecToRemove)($"features")) // Also possible

dfVec.show yields:

+---+---+---+--------------+
| c1| c2| c3|      features|
+---+---+---+--------------+
|  1|  1|  1| [1.0,1.0,1.0]|
|  1|  2|  3| [1.0,2.0,3.0]|
|  1|  3|  5| [1.0,3.0,5.0]|
|  2|  4|  6| [2.0,4.0,6.0]|
|  2|  5|  2| [2.0,5.0,2.0]|
|  2|  6|  1| [2.0,6.0,1.0]|
|  3|  7|  5| [3.0,7.0,5.0]|
|  3|  8| 16|[3.0,8.0,16.0]|
|  1|  1|  1| [1.0,1.0,1.0]|
+---+---+---+--------------+

filtered.show yields:

+---+---+---+--------------+
| c1| c2| c3|      features|
+---+---+---+--------------+
|  1|  2|  3| [1.0,2.0,3.0]|
|  1|  3|  5| [1.0,3.0,5.0]|
|  2|  4|  6| [2.0,4.0,6.0]|
|  2|  5|  2| [2.0,5.0,2.0]|
|  2|  6|  1| [2.0,6.0,1.0]|
|  3|  7|  5| [3.0,7.0,5.0]|
|  3|  8| 16|[3.0,8.0,16.0]|
+---+---+---+--------------+
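For reference, the UDF's `!vec1.equals(vec2)` predicate boils down to comparing vector lengths and components. A minimal pure-Scala sketch of that check, written on plain `Array[Double]` so it runs without a Spark session (`VectorEqualsSketch` and `vectorsEqual` are illustrative names, not part of the answer above):

```scala
object VectorEqualsSketch {
  // Elementwise equality: same length and identical components,
  // mirroring how !vec1.equals(vec2) behaves for two dense vectors
  def vectorsEqual(a: Array[Double], b: Array[Double]): Boolean =
    a.length == b.length && a.indices.forall(i => a(i) == b(i))

  def main(args: Array[String]): Unit = {
    assert(vectorsEqual(Array(1.0, 1.0, 1.0), Array(1.0, 1.0, 1.0)))
    assert(!vectorsEqual(Array(1.0, 2.0, 3.0), Array(1.0, 1.0, 1.0)))
    println("ok")
  }
}
```

Note that the real `Vector.equals` in Spark ML also handles mixed sparse/dense comparisons; this sketch only covers the dense case.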

Answer 1 (score: 0)

My case: raw data after word2vec:

result.show(10,false)

+-------------+-----------------------------------------------------------------------------------------------------------+
|ip           |features                                                                                                   |
+-------------+-----------------------------------------------------------------------------------------------------------+
|1.1.125.120  |[0.0,0.0,0.0,0.0,0.0]                                                                                      |
|1.11.114.150 |[0.0,0.0,0.0,0.0,0.0]                                                                                      |
|1.116.114.36 |[0.022845590487122536,-0.012075710110366344,-0.034423209726810455,-0.04642726108431816,0.09164007753133774]|
|1.117.21.102 |[0.0,0.0,0.0,0.0,0.0]                                                                                      |
|1.119.13.5   |[0.0,0.0,0.0,0.0,0.0]                                                                                      |
|1.119.130.2  |[0.0,0.0,0.0,0.0,0.0]                                                                                      |
|1.119.132.162|[0.0,0.0,0.0,0.0,0.0]                                                                                      |
|1.119.133.166|[0.0,0.0,0.0,0.0,0.0]                                                                                      |
|1.119.136.170|[0.0,0.0,0.0,0.0,0.0]                                                                                      |
|1.119.137.154|[0.0,0.0,0.0,0.0,0.0]                                                                                      |
+-------------+-----------------------------------------------------------------------------------------------------------+

I want to drop the IPs whose embedding is all zeros:

import org.apache.spark.sql.functions.udf
import org.apache.spark.ml.linalg.Vector

val vecToSeq = udf((v: Vector) => v.toArray).asNondeterministic
val output = result.select($"ip", vecToSeq($"features").alias("features"))

// =!= is Spark's Column inequality operator (!== is deprecated), and
// typedLit is needed to turn the Array literal into an array column;
// the elements must be Double to match the array produced by the UDF
import org.apache.spark.sql.functions.typedLit
val select_output = output.filter(output("features") =!= typedLit(Array(0.0, 0.0, 0.0, 0.0, 0.0)))
select_output.show(5)


+-------------+--------------------+
|           ip|            features|
+-------------+--------------------+
| 1.116.114.36|[0.02284559048712...|
| 1.119.137.98|[-0.0244039318391...|
|1.119.177.102|[-0.0801128149032...|
|1.119.186.170|[0.01125990878790...|
|1.119.193.226|[0.04201301932334...|
+-------------+--------------------+
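Converting the vector to an array and comparing it to a literal is one route; another is to test the vector directly inside a single UDF, keeping rows where any component is nonzero. The predicate logic, sketched on plain `Array[Double]` so it runs without Spark (`ZeroVectorSketch` and `isNonZero` are illustrative names):

```scala
object ZeroVectorSketch {
  // True when at least one component is nonzero, i.e. the row should be kept
  def isNonZero(v: Array[Double]): Boolean = v.exists(_ != 0.0)

  def main(args: Array[String]): Unit = {
    assert(!isNonZero(Array(0.0, 0.0, 0.0, 0.0, 0.0)))
    assert(isNonZero(Array(0.0228, -0.0121, -0.0344, -0.0464, 0.0916)))
    println("ok")
  }
}
```

In Spark this could presumably be expressed as something like `udf((v: Vector) => v.numNonzeros > 0)` applied in a `filter`, avoiding the array conversion entirely.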