Spark-ML 自定义模型与转换器(Transformer)

时间:2017-06-29 18:36:37

标签: scala apache-spark-ml

运行环境:Spark 2.0.1。

我正在尝试编译并运行原帖链接(原文为 "here")中给出的 SimpleIndexer 示例。

import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.ml._

import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._

/**
 * Params shared by [[SimpleIndexer]] and [[SimpleIndexerModel]].
 *
 * Both the estimator and the fitted model mix this trait in, so the model exposes the same
 * params as the estimator and `copyValues` can carry the estimator's settings over to it.
 */
trait SimpleIndexerParams extends Params {

  /** Name of the string column whose values are indexed. */
  final val inputCol = new Param[String](this, "inputCol", "The input column")

  /** Name of the column that receives the computed index. */
  final val outputCol = new Param[String](this, "outputCol", "The output column")

  /**
   * @return the configured input column name
   * @throws java.util.NoSuchElementException if `inputCol` was never set (it has no default)
   */
  final def getInputCol: String = $(inputCol)

  /**
   * @return the configured output column name
   * @throws java.util.NoSuchElementException if `outputCol` was never set (it has no default)
   */
  final def getOutputCol: String = $(outputCol)
}

/**
 * Estimator that assigns each distinct value of the string column `inputCol` an index and
 * produces a [[SimpleIndexerModel]] that writes that index to `outputCol`.
 *
 * @param uid unique identifier of this stage
 */
class SimpleIndexer(override val uid: String) extends Estimator[SimpleIndexerModel] with SimpleIndexerParams {

  /** Sets the name of the string column to index. */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** Sets the name of the column the fitted model writes indices to. */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  def this() = this(Identifiable.randomUID("simpleindexer"))

  override def copy(extra: ParamMap): SimpleIndexer = defaultCopy(extra)

  /**
   * Checks that `inputCol` is a string column and that `outputCol` does not exist yet, then
   * appends the non-nullable Double column the fitted model produces.
   *
   * @throws IllegalArgumentException if `inputCol` is missing or not a string column, or if
   *                                  `outputCol` already exists
   */
  override def transformSchema(schema: StructType): StructType = {
    val field = schema($(inputCol))
    if (field.dataType != StringType) {
      throw new IllegalArgumentException(s"Input type ${field.dataType} did not match input type StringType")
    }
    require(!schema.fieldNames.contains($(outputCol)), s"Output column ${$(outputCol)} already exists.")
    // The model's lookup table holds Double indices (labelToIndex: Map[String, Double]),
    // so the advertised output type must be DoubleType, not IntegerType.
    schema.add(StructField($(outputCol), DoubleType, nullable = false))
  }

  /**
   * Collects the distinct values of `inputCol` and returns a model that carries this
   * estimator's params and points back to it as its parent.
   *
   * @throws IllegalArgumentException if the schema check in [[transformSchema]] fails
   */
  override def fit(dataset: Dataset[_]): SimpleIndexerModel = {
    transformSchema(dataset.schema)
    import dataset.sparkSession.implicits._
    // NOTE(review): each label's index follows the row order of distinct().collect(), which is
    // not guaranteed to be stable across runs; sort `words` if reproducible indices matter.
    val words = dataset.select(dataset($(inputCol)).as[String]).distinct().collect()
    // copyValues is what lets the fitted model see inputCol/outputCol. The version originally
    // posted returned `new SimpleIndexerModel(uid, words)` directly, leaving the model's params
    // unset, so transform() failed with "Failed to find a default value for inputCol".
    copyValues(new SimpleIndexerModel(uid, words).setParent(this))
  }
}

/**
 * Model produced by [[SimpleIndexer]].
 *
 * Replaces each value of `inputCol` with that value's position in `words`, as a Double, and
 * appends the result as `outputCol`.
 *
 * @param uid   unique identifier of this stage (SimpleIndexer.fit passes its own uid)
 * @param words labels collected during fitting; `words(i)` maps to `i.toDouble`
 */
class SimpleIndexerModel(
  override val uid: String, words: Array[String]) extends Model[SimpleIndexerModel] with SimpleIndexerParams {

  /**
   * Returns a copy with the same labels, the same params (overridden by `extra`) and the same parent.
   *
   * `defaultCopy` cannot be used here. It re-instantiates the class through a constructor that
   * takes only `uid`, and this class has no such constructor, so it would fail at runtime.
   */
  override def copy(extra: ParamMap): SimpleIndexerModel = {
    val copied = new SimpleIndexerModel(uid, words)
    copyValues(copied, extra).setParent(parent)
  }

  /** Label -> index lookup built from `words` (labels presumably distinct, since fit uses `distinct`). */
  private val labelToIndex: Map[String, Double] =
    words.zipWithIndex.map { case (label, index) => label -> index.toDouble }.toMap

  /**
   * Checks that `inputCol` is a string column and that `outputCol` does not exist yet, then
   * appends the non-nullable Double index column.
   *
   * @throws IllegalArgumentException if `inputCol` is missing or not a string column, or if
   *                                  `outputCol` already exists
   */
  override def transformSchema(schema: StructType): StructType = {
    val field = schema($(inputCol))
    if (field.dataType != StringType) {
      throw new IllegalArgumentException(s"Input type ${field.dataType} did not match input type StringType")
    }
    require(!schema.fieldNames.contains($(outputCol)), s"Output column ${$(outputCol)} already exists.")
    // transform() emits Doubles from labelToIndex, so the declared type must be DoubleType.
    schema.add(StructField($(outputCol), DoubleType, nullable = false))
  }

  /**
   * Appends `outputCol` holding the index of each row's `inputCol` value.
   *
   * A label that was not seen during fitting makes the UDF throw NoSuchElementException
   * when the rows are evaluated.
   *
   * @throws IllegalArgumentException if the schema check in [[transformSchema]] fails
   */
  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema)
    // Local copy, so the UDF closure captures only the lookup map and not the whole model.
    val mapping = labelToIndex
    val indexer = udf { label: String =>
      mapping.getOrElse(label, throw new NoSuchElementException(s"Unseen label: $label"))
    }
    dataset.select(col("*"),
      indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol)))
  }
}

但是,在调用 transform 时出现了错误:

// Toy input: an integer id plus the phrase column that will be indexed.
val df = Seq(
  10 -> "hello",
  20 -> "World",
  30 -> "goodbye",
  40 -> "sky"
).toDF("id", "phrase")

// Configure the estimator first, then fit it on the data.
val estimator = new SimpleIndexer().setInputCol("phrase").setOutputCol("phrase_idx")
val si = estimator.fit(df)

si.transform(df).show(truncate = false)

// If fit() returns the model without copyValues(...), the model's params are never set and
// the transform call above fails with:
// java.util.NoSuchElementException: Failed to find a default value for inputCol

知道怎么解决吗?

2 个答案:

答案 0 :(得分:1)

SimpleIndexer 的 transform 方法似乎接受 Dataset 作为参数,而不是 DataFrame(也就是你传入的对象)。

// NOTE(review): the reported error is a runtime exception, so transform(df) already compiles.
// DataFrame is Dataset[Row], which matches transform(dataset: Dataset[_]). Converting to
// Dataset[Phrase] does not by itself set the model's inputCol/outputCol params.
case class Phrase(id: Int, phrase:String)
si.transform(df.as[Phrase])....

有关详细信息,请参阅文档:https://spark.apache.org/docs/2.0.1/sql-programming-guide.html

编辑:问题似乎在于 SimpleIndexerModel 无法通过表达式 $(inputCol) 访问 "phrase" 列。我认为这是因为该参数是在 SimpleIndexer 类中设置的(在那里上述表达式可以正常工作),而在 SimpleIndexerModel 中却访问不到。

一种解决方案是手动设置列名:

// Fragment of the select(...) call inside SimpleIndexerModel.transform. It hard-codes the column
// names "phrase" / "phrase_idx", so the inputCol/outputCol params are bypassed entirely.
indexer(dataset.col("phrase").cast(StringType)).as("phrase_idx"))

不过,更好的做法可能是在实例化 SimpleIndexerModel 时传入列名:

// Column names arrive as plain constructor arguments rather than through the inputCol/outputCol
// Params. "...." presumably elides the rest of the class body.
class SimpleIndexerModel(override val uid: String, words: Array[String], inputColName: String, outputColName: String)
....

// Presumably inside SimpleIndexer.fit. It reads the estimator's params with $(...) and passes the
// plain values to the model's constructor, instead of copying the params themselves.
new SimpleIndexerModel(uid, words, $(inputCol), $(outputCol))

结果:

+---+-------+----------+
|id |phrase |phrase_idx|
+---+-------+----------+
|10 |hello  |1.0       |
|20 |World  |0.0       |
|30 |goodbye|3.0       |
|40 |sky    |2.0       |
+---+-------+----------+

答案 1 :(得分:0)

好的,我通过阅读 CountVectorizer 的源代码找到了原因:需要把 new SimpleIndexerModel(uid, words) 替换为 copyValues(new SimpleIndexerModel(uid, words).setParent(this))。因此,新的 fit 方法变为:

  /**
   * Collects the distinct values of `inputCol` and returns a model that carries this
   * estimator's params and points back to it as its parent.
   *
   * @throws IllegalArgumentException if `inputCol` is missing or not a string column
   */
  override def fit(dataset: Dataset[_]): SimpleIndexerModel = {
    // Fail fast with a clear message instead of collecting from an invalid column.
    transformSchema(dataset.schema)
    import dataset.sparkSession.implicits._
    val words = dataset.select(dataset($(inputCol)).as[String]).distinct().collect()
    // copyValues transfers inputCol/outputCol from this estimator to the model. Without it, the
    // model's params are unset and transform() fails with
    // "Failed to find a default value for inputCol". setParent records the producing estimator.
    copyValues(new SimpleIndexerModel(uid, words).setParent(this))
  }

这样一来,模型就能识别这些参数,转换也能顺利完成。原因是 copyValues 会把估计器(Estimator)上已设置的参数值(inputCol、outputCol)复制到新建的模型上,而 setParent(this) 则记录生成该模型的估计器。

// With the copyValues-based fit, the fitted model carries the estimator's param values.
val si = new SimpleIndexer().setInputCol("phrase").setOutputCol("phrase_idx").fit(df)

si.explainParams
// res3: String =
// inputCol: The input column (current: phrase)
// outputCol: The output column (current: phrase_idx)

// NOTE(review): each phrase's index follows the row order of distinct().collect() in fit, which
// is not guaranteed to be stable across runs. Confirm before relying on the exact values below.
si.transform(df).show(false)
// +---+-------+----------+
// |id |phrase |phrase_idx|
// +---+-------+----------+
// |10 |hello  |1.0       |
// |20 |World  |0.0       |
// |30 |goodbye|3.0       |
// |40 |sky    |2.0       |
// +---+-------+----------+