如何将列转换为矢量类型?

时间:2016-03-18 01:23:06

标签: scala apache-spark apache-spark-ml

我在Spark中有一个RDD,其中的对象基于案例类:

ExampleCaseClass(user: User, stuff: Stuff)

我想使用Spark的ML管道,所以我将其转换为Spark数据帧。作为管道的一部分,我想将其中一列转换为条目为向量的列。由于我希望该向量的长度随模型而变化,因此应将其作为特征转换的一部分构建到管道中。

所以我试图按如下方式定义Transformer:

class MyTransformer extends Transformer {

  val uid = ""
  val num: IntParam = new IntParam(this, "", "")

  def setNum(value: Int): this.type = set(num, value)
  setDefault(num -> 50)

  def transform(df: DataFrame): DataFrame = {
    ...
  }

  def transformSchema(schema: StructType): StructType = {
    val inputFields = schema.fields
    StructType(inputFields :+ StructField("colName", ???, true))
  }

  def copy (extra: ParamMap): Transformer = defaultCopy(extra)

}

如何指定结果字段的DataType(即填写???)?它将是一个简单类的Vector(Boolean,Int,Double等)。看起来VectorUDT可能有用,但这对Spark是私有的。由于任何RDD都可以转换为DataFrame,因此任何案例类都可以转换为自定义DataType。但是,我无法弄清楚如何手动执行此转换,否则我可以将它应用于包装矢量的一些简单案例类。

此外,如果我为列指定了矢量类型,当我适应模型时,VectorAssembler会将矢量正确地处理成单独的特征吗?

仍然是Spark的新手,尤其是ML Pipeline,所以感谢任何建议。

2 个答案:

答案 0 :(得分:4)

import org.apache.spark.ml.linalg.SQLDataTypes.VectorType  
def transformSchema(schema: StructType): StructType = {
  val inputFields = schema.fields
  StructType(inputFields :+ StructField("colName", VectorType, true))
}

在spark 2.1中,VectorType使VectorUDT公开可用:

package org.apache.spark.ml.linalg

import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.sql.types.DataType

/**
 * :: DeveloperApi ::
 * SQL data types for vectors and matrices.
 */
@Since("2.0.0")
@DeveloperApi
object SQLDataTypes {

  /** Data type for [[Vector]]. */
  val VectorType: DataType = new VectorUDT

  /** Data type for [[Matrix]]. */
  val MatrixType: DataType = new MatrixUDT
}

答案 1 :(得分:3)

import org.apache.spark.mllib.linalg.{Vector, Vectors}

case class MyVector(vector: Vector)
val vectorDF = Seq(
  MyVector(Vectors.dense(1.0,3.4,4.4)),
  MyVector(Vectors.dense(5.5,6.7))
).toDF

vectorDF.printSchema
root
 |-- vector: vector (nullable = true)

println(vectorDF.schema.fields(0).dataType.prettyJson)
{
  "type" : "udt",
  "class" : "org.apache.spark.mllib.linalg.VectorUDT",
  "pyClass" : "pyspark.mllib.linalg.VectorUDT",
  "sqlType" : {
    "type" : "struct",
    "fields" : [ {
      "name" : "type",
      "type" : "byte",
      "nullable" : false,
      "metadata" : { }
    }, {
      "name" : "size",
      "type" : "integer",
      "nullable" : true,
      "metadata" : { }
    }, {
      "name" : "indices",
      "type" : {
        "type" : "array",
        "elementType" : "integer",
        "containsNull" : false
      },
      "nullable" : true,
      "metadata" : { }
    }, {
      "name" : "values",
      "type" : {
        "type" : "array",
        "elementType" : "double",
        "containsNull" : false
      },
      "nullable" : true,
      "metadata" : { }
    } ]
  }
}