我想为spark开发一个自定义估算器,它也可以处理大型管道API的持久性。但是正如How to Roll a Custom Estimator in PySpark mllib所说,那里还没有很多文件(<)。
我有一些用spark编写的数据清理代码,并希望将其包装在自定义估算器中。包括一些na替换,列删除,过滤和基本特征生成(例如,出生日期到年龄)。
ScalaReflection.schemaFor[MyClass].dataType.asInstanceOf[StructType]
我还不清楚:
transform
将用于转换&#34; fit&#34;关于新数据的估算器。它是否正确?如果是,我应该如何转移拟合值,例如从上面进入模型的平均年龄?
如何处理持久性?我在私有spark组件中找到了一些通用的loadImpl
方法,但我不确定如何传输我自己的参数,例如MLReader
/ MLWriter
中用于序列化的平均年龄。
如果你可以帮助我使用自定义估算器,那将会很棒 - 特别是对于持久性部分。
答案 0 :(得分:2)
首先,我相信你会混合两件不同的东西:
Estimators
- 代表可以fit
的阶段。 Estimator
fit
方法需要Dataset
并返回Transformer
(模型)。Transformers
- 代表可以transform
数据的阶段。当您fit
Pipeline
fits
全部Estimators
并返回PipelineModel
时。 PipelineModel
可以transform
数据在模型中的所有transform
上依次调用Transformers
。
我应该如何转移拟合值
这个问题没有一个答案。通常,您有两种选择:
Transformer
。Transformer
。Params
的参数
第一种方法通常由内置Transformer
使用,但第二种方法应该在一些简单的情况下使用。
如何处理持久性
Transformer
仅由其Params
定义,则可以延长DefaultParamsReadable
。MLWritable
并实施对您的数据有意义的MLWriter
。 Spark源代码中有多个示例,展示了如何实现数据和元数据的读/写。如果您正在寻找易于理解的示例,请查看CountVectorizer(Model)
其中:
Estimator
和Transformer
share common Params
。DefaultParamsWriter
/ DefaultParamsReader
的元数据(参数)为written和read。答案 1 :(得分:2)
以下使用 Scala API ,但如果您真的想要,可以轻松地将其重构为Python。
首先要做的事情:
.fit()
返回Transformer .transform()
并操纵DataFrame DefaultParamsWritable
特征 + 配套对象扩展DefaultParamsReadable[T]
。 a.k.a远离MLReader / MLWriter并保持代码简单。Params
的公共特征并在Estimator和Model(a.k.a.Transformer)之间共享骨架代码:
// Common Parameters
trait MyCommonParams extends Params {
final val inputCols: StringArrayParam = // usage: new MyMeanValueStuff().setInputCols(...)
new StringArrayParam(this, "inputCols", "doc...")
def setInputCols(value: Array[String]): this.type = set(inputCols, value)
def getInputCols: Array[String] = $(inputCols)
final val meanValues: DoubleArrayParam =
new DoubleArrayParam(this, "meanValues", "doc...")
// more setters and getters
}
// Estimator
class MyMeanValueStuff(override val uid: String) extends Estimator[MyMeanValueStuffModel]
with DefaultParamsWritable // Enables Serialization of MyCommonParams
with MyCommonParams {
override def copy(extra: ParamMap): Estimator[MeanValueFillerModel] = defaultCopy(extra) // deafult
override def transformSchema(schema: StructType): StructType = schema // no changes
override def fit(dataset: Dataset[_]): MyMeanValueStuffModel = {
// your logic here. I can't do all the work for you! ;)
this.setMeanValues(meanValues)
copyValues(new MyMeanValueStuffModel(uid + "_model").setParent(this))
}
}
// Companion object enables deserialization of MyCommonParams
object MyMeanValueStuff extends DefaultParamsReadable[MyMeanValueStuff]
// Model (Transformer)
class MyMeanValueStuffModel(override val uid: String) extends Model[MyMeanValueStuffModel]
with DefaultParamsWritable // Enables Serialization of MyCommonParams
with MyCommonParams {
override def copy(extra: ParamMap): MyMeanValueStuffModel = defaultCopy(extra) // default
override def transformSchema(schema: StructType): StructType = schema // no changes
override def transform(dataset: Dataset[_]): DataFrame = {
// your logic here: zip inputCols and meanValues, toMap, replace nulls with NA functions
// you have access to both inputCols and meanValues here!
}
}
// Companion object enables deserialization of MyCommonParams
object MyMeanValueStuffModel extends DefaultParamsReadable[MyMeanValueStuffModel]
使用上面的代码,您可以序列化/反序列化包含MyMeanValueStuff阶段的管道。
想看一下Estimator的一些简单实现吗? MinMaxScaler! (我的例子实际上更简单......)