我是Spark的相对初学者。我有一个宽的数据框(1000列),我想根据相应的列是否缺少值来添加列
所以
+----+ | A | +----+ | 1 | +----+ |null| +----+ | 3 | +----+
变为
+----+-------+ | A | A_MIS | +----+-------+ | 1 | 0 | +----+-------+ |null| 1 | +----+-------+ | 3 | 1 | +----+-------+
这是自定义ml变换器的一部分,但算法应该清楚。
Logo in /
在列上循环,如果> 0 nulls创建一个新列。
传入的数据集被缓存(使用.cache方法),相关的配置设置是默认值。 现在,它在一台笔记本电脑上运行,即使行数最少,也会以1000分钟的速度运行1000列。 我认为这个问题是由于访问数据库,所以我尝试使用镶木地板文件而不是相同的结果。查看作业UI,它似乎正在进行文件扫描以进行计数。
有没有办法可以改进此算法以获得更好的性能,或以某种方式调整缓存?增加spark.sql.inMemoryColumnarStorage.batchSize只是给我一个OOM错误。
答案 0 :(得分:1)
删除条件:
if (dataset.filter(col(c).isNull).count() > 0)
并只保留内部表达式。正如它所写,Spark需要#columns数据扫描。
如果您希望修剪列计算统计信息一次,如Count number of non-NaN entries in each column of Spark dataframe with Pyspark中所述,并使用单drop
次调用。
答案 1 :(得分:0)
以下是修复问题的代码。
override def transform(dataset: Dataset[_]): DataFrame = {
var ds = dataset
val rowCount = dataset.count()
val exprs = dataset.columns.map(count(_))
val colCounts = dataset.agg(exprs.head, exprs.tail: _*).toDF(dataset.columns: _*).first()
dataset.columns.foreach(c => {
if (colCounts.getAs[Long](c) > 0 && colCounts.getAs[Long](c) < rowCount ) {
ds = ds.withColumn(c + "_MIS", when(col(c).isNull, 1).otherwise(0))
}
})
ds.toDF()
}