用户在Pyspark管道中定义了变换器

时间:2017-07-19 12:20:06

标签: python apache-spark machine-learning pyspark spark-dataframe

我正在尝试创建一个pyspark管道来运行分类模型。我的数据集有一个字符串列。所以我在使用管道中的模型之前使用'StringIndexer'将其转换为数字。

我的管道只包含2个阶段 StringIndexer ClassificationModel

StringIndexer正在创建一个带索引的新列,但是也会保留旧列。我想在管道中引入一个新的变换器来删除一个“字符串”列。这可能吗?

有没有其他方法可以删除 StringIndexer 中的实际列?

由于

1 个答案:

答案 0 :(得分:2)

是的,您可以扩展abstract class Transformer并创建自己的变换器,从而删除不必要的列。

这应该类似于以下内容:

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.{
  ArrayType,
  StringType,
  StructField,
  StructType
}
import org.apache.spark.sql.functions.collect_list

class Dropper(override val uid: String) extends Transformer {

  def this() = this(Identifiable.randomUID("dropper"))

  override def transform(dataset: Dataset[_]): DataFrame = {
    dataset.drop("your-column-name-here")
  }

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  override def transformSchema(schema: StructType): StructType = {
    //here you should right your result schema i.e. the schema without the dropped column
  }

}

我已经做了一段时间了,这对我很有用。

请注意,您还可以扩展abstract class Estimator

希望它有所帮助。最诚挚的问候