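For some background, I am trying to implement Kaplan-Meier in Spark. In particular, I assume I have a DataFrame/Dataset with a Double column named data and an Int column named censorFlag (0 if censored, 1 if not; I prefer this over a Boolean type).
Example: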
val df = Seq((1.0, 1), (2.3, 0), (4.5, 1), (0.8, 1), (0.7, 0), (4.0, 1), (0.8, 1)).toDF("data", "censorFlag").as[(Double, Int)]
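Now I need to compute a column wins that counts the uncensored instances (censorFlag = 1) of each distinct data value. I achieve that with the following code: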
val distDF = df.withColumn("wins", sum(col("censorFlag")).over(Window.partitionBy("data").orderBy("data")))
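The problem comes when I need to compute a quantity called atRisk, which counts, for each value of data, the number of data points that are greater than or equal to it (a cumulative filtered count, if you will).
The following code works: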
// We perform the counts per value of "bins". This is an array of doubles
val bins = df.select(col("data").as("dataBins")).distinct().sort("dataBins").as[Double].collect
val atRiskCounts = bins.map(x => (x, df.filter(col("data").geq(x)).count)).toSeq.toDF("data", "atRisk")
// this works:
atRiskCounts.show
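As a sanity check on what wins and atRisk should contain for the sample data, here is a plain Scala sketch using ordinary collections and no Spark. The names sample, refBins and reference are made up for this illustration. It is the same brute-force count as above, just on a local Seq.

```scala
// Plain Scala, no Spark: a reference calculation for the sample data.
val sample: Seq[(Double, Int)] =
  Seq((1.0, 1), (2.3, 0), (4.5, 1), (0.8, 1), (0.7, 0), (4.0, 1), (0.8, 1))

// Distinct observed values, ascending.
val refBins: Seq[Double] = sample.map(_._1).distinct.sorted

// wins:   sum of censorFlag over the rows sharing a value.
// atRisk: number of rows whose value is >= the bin.
val reference: Seq[(Double, Int, Int)] = refBins.map { x =>
  val wins   = sample.collect { case (d, flag) if d == x => flag }.sum
  val atRisk = sample.count { case (d, _) => d >= x }
  (x, wins, atRisk)
}

reference.foreach { case (x, w, r) => println(f"$x%4.1f  wins=$w  atRisk=$r") }
```

For this sample, 0.8 appears twice, so it should come out with wins = 2 and atRisk = 6, and the smallest value 0.7 with atRisk = 7.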
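However, the use case involves deriving bins from the data column itself, which I would rather leave as a single-column Dataset (or an RDD at worst), but certainly not as a local array. But this doesn't work: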
// Here, 'bins' rightfully come from the data itself.
val bins = df.select(col("data").as("dataBins")).distinct().as[Double]
val atRiskCounts = bins.map(x => (x, df.filter(col("data").geq(x)).count)).toDF("data", "atRisk")
// This doesn't work -- NullPointerException
atRiskCounts.show
Nor does this:
// Manually creating the bins and then parallelizing them.
val bins = Seq(0.7, 0.8, 1.0, 3.0).toDS
val atRiskCounts = bins.map(x => (x, df.filter(col("data").geq(x)).count)).toDF("data", "atRisk")
// Also fails with a NullPointerException
atRiskCounts.show
Another approach that works, but is also unsatisfactory from a parallelization standpoint, is using Window:
// Do the counts in one fell swoop using a giant window per value.
val atRiskCounts = df.withColumn("atRisk", count("censorFlag").over(Window.orderBy("data").rowsBetween(0, Window.unboundedFollowing))).groupBy("data").agg(first("atRisk").as("atRisk"))
// Works, BUT, we get a "WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation."
atRiskCounts.show
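This last solution isn't practical, since it ends up shuffling all my data into a single partition (in which case I might as well go with option 1).
The approaches that succeed are fine, except that the bins are not distributed, which is something I would really like to keep. I have looked at groupBy aggregations and pivot-type aggregations, but none of them seem to fit.
My question is: is there any way to compute the atRisk column in a distributed way? Also, why do I get a NullPointerException in the failed solutions?
EDIT per comment:
I originally didn't post the NullPointerException because it didn't seem to include anything useful. I'll note that this is Spark installed via Homebrew on my MacBook Pro (Spark version 2.2.1, standalone localhost mode).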
18/03/12 11:41:00 ERROR ExecutorClassLoader: Failed to check existence of class <root>.package on REPL class server at spark://10.37.109.111:53360/classes
java.net.URISyntaxException: Illegal character in path at index 36: spark://10.37.109.111:53360/classes/<root>/package.class
at java.net.URI$Parser.fail(URI.java:2848)
at java.net.URI$Parser.checkChars(URI.java:3021)
at java.net.URI$Parser.parseHierarchical(URI.java:3105)
at java.net.URI$Parser.parse(URI.java:3053)
at java.net.URI.<init>(URI.java:588)
at org.apache.spark.rpc.netty.NettyRpcEnv.openChannel(NettyRpcEnv.scala:327)
at org.apache.spark.repl.ExecutorClassLoader.org$apache$spark$repl$ExecutorClassLoader$$getClassFileInputStreamFromSparkRPC(ExecutorClassLoader.scala:90)
at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
at org.apache.spark.repl.ExecutorClassLoader.findClassLocally(ExecutorClassLoader.scala:162)
at org.apache.spark.repl.ExecutorClassLoader.findClass(ExecutorClassLoader.scala:80)
at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
. . . .
18/03/12 11:41:00 ERROR ExecutorClassLoader: Failed to check existence of class <root>.scala on REPL class server at spark://10.37.109.111:53360/classes
java.net.URISyntaxException: Illegal character in path at index 36: spark://10.37.109.111:53360/classes/<root>/scala.class
at java.net.URI$Parser.fail(URI.java:2848)
at java.net.URI$Parser.checkChars(URI.java:3021)
at java.net.URI$Parser.parseHierarchical(URI.java:3105)
at java.net.URI$Parser.parse(URI.java:3053)
at java.net.URI.<init>(URI.java:588)
at org.apache.spark.rpc.netty.NettyRpcEnv.openChannel(NettyRpcEnv.scala:327)
at org.apache.spark.repl.ExecutorClassLoader.org$apache$spark$repl$ExecutorClassLoader$$getClassFileInputStreamFromSparkRPC(ExecutorClassLoader.scala:90)
at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
at org.apache.spark.repl.ExecutorClassLoader.findClassLocally(ExecutorClassLoader.scala:162)
at org.apache.spark.repl.ExecutorClassLoader.findClass(ExecutorClassLoader.scala:80)
at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
. . .
18/03/12 11:41:00 ERROR ExecutorClassLoader: Failed to check existence of class <root>.org on REPL class server at spark://10.37.109.111:53360/classes
java.net.URISyntaxException: Illegal character in path at index 36: spark://10.37.109.111:53360/classes/<root>/org.class
at java.net.URI$Parser.fail(URI.java:2848)
at java.net.URI$Parser.checkChars(URI.java:3021)
at java.net.URI$Parser.parseHierarchical(URI.java:3105)
at java.net.URI$Parser.parse(URI.java:3053)
at java.net.URI.<init>(URI.java:588)
at org.apache.spark.rpc.netty.NettyRpcEnv.openChannel(NettyRpcEnv.scala:327)
at org.apache.spark.repl.ExecutorClassLoader.org$apache$spark$repl$ExecutorClassLoader$$getClassFileInputStreamFromSparkRPC(ExecutorClassLoader.scala:90)
at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
at org.apache.spark.repl.ExecutorClassLoader.findClassLocally(ExecutorClassLoader.scala:162)
at org.apache.spark.repl.ExecutorClassLoader.findClass(ExecutorClassLoader.scala:80)
at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
. . .
18/03/12 11:41:00 ERROR ExecutorClassLoader: Failed to check existence of class <root>.java on REPL class server at spark://10.37.109.111:53360/classes
java.net.URISyntaxException: Illegal character in path at index 36: spark://10.37.109.111:53360/classes/<root>/java.class
at java.net.URI$Parser.fail(URI.java:2848)
at java.net.URI$Parser.checkChars(URI.java:3021)
at java.net.URI$Parser.parseHierarchical(URI.java:3105)
at java.net.URI$Parser.parse(URI.java:3053)
at java.net.URI.<init>(URI.java:588)
at org.apache.spark.rpc.netty.NettyRpcEnv.openChannel(NettyRpcEnv.scala:327)
at org.apache.spark.repl.ExecutorClassLoader.org$apache$spark$repl$ExecutorClassLoader$$getClassFileInputStreamFromSparkRPC(ExecutorClassLoader.scala:90)
at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
at org.apache.spark.repl.ExecutorClassLoader.findClassLocally(ExecutorClassLoader.scala:162)
at org.apache.spark.repl.ExecutorClassLoader.findClass(ExecutorClassLoader.scala:80)
at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
. . .
18/03/12 11:41:00 ERROR Executor: Exception in task 0.0 in stage 55.0 (TID 432)
java.lang.NullPointerException
at org.apache.spark.sql.Dataset.<init>(Dataset.scala:171)
at org.apache.spark.sql.Dataset$.apply(Dataset.scala:62)
at org.apache.spark.sql.Dataset.withTypedPlan(Dataset.scala:2889)
at org.apache.spark.sql.Dataset.filter(Dataset.scala:1301)
at $line124.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:33)
at $line124.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:33)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:395)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:234)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:228)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
at org.apache.spark.scheduler.Task.run(Task.scala:108)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:338)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
at java.lang.Thread.run(Thread.java:748)
18/03/12 11:41:00 WARN TaskSetManager: Lost task 0.0 in stage 55.0 (TID 432, localhost, executor driver): java.lang.NullPointerException
at org.apache.spark.sql.Dataset.<init>(Dataset.scala:171)
at org.apache.spark.sql.Dataset$.apply(Dataset.scala:62)
at org.apache.spark.sql.Dataset.withTypedPlan(Dataset.scala:2889)
at org.apache.spark.sql.Dataset.filter(Dataset.scala:1301)
at $line124.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:33)
at $line124.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:33)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:395)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:234)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:228)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
at org.apache.spark.scheduler.Task.run(Task.scala:108)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:338)
18/03/12 11:41:00 ERROR TaskSetManager: Task 0 in stage 55.0 failed 1 times; aborting job
org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 55.0 failed 1 times, most recent failure: Lost task 0.0 in stage 55.0 (TID 432, localhost, executor driver): java.lang.NullPointerException
at org.apache.spark.sql.Dataset.<init>(Dataset.scala:171)
at org.apache.spark.sql.Dataset$.apply(Dataset.scala:62)
at org.apache.spark.sql.Dataset.withTypedPlan(Dataset.scala:2889)
at org.apache.spark.sql.Dataset.filter(Dataset.scala:1301)
at $anonfun$1.apply(<console>:33)
at $anonfun$1.apply(<console>:33)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:395)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:234)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:228)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
at org.apache.spark.scheduler.Task.run(Task.scala:108)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:338)
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1517)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1505)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1504)
at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
... 50 elided
Caused by: java.lang.NullPointerException
at org.apache.spark.sql.Dataset.<init>(Dataset.scala:171)
at org.apache.spark.sql.Dataset$.apply(Dataset.scala:62)
at org.apache.spark.sql.Dataset.withTypedPlan(Dataset.scala:2889)
at org.apache.spark.sql.Dataset.filter(Dataset.scala:1301)
at $anonfun$1.apply(<console>:33)
at $anonfun$1.apply(<console>:33)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:395)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:234)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:228)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
at org.apache.spark.scheduler.Task.run(Task.scala:108)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:338)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
at java.lang.Thread.run(Thread.java:748)
My best guess is that the df("data").geq(x).count part might be what barfs, since not every node may have x, and hence the null pointer?
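Answer 0 (score: 1)
I haven't tested this, so the syntax may be goofy, but I would do a series of joins.
I believe your first statement is equivalent to this: for each data value, count how many wins there are: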
val distDF = df.groupBy($"data").agg(sum($"censorFlag").as("wins"))
Then, as you note, we can build a DataFrame of the bins:
val distinctData = df.select($"data".as("dataBins")).distinct()
Then join with a >= condition:
val atRiskCounts = distDF.join(distinctData, distDF("data") >= distinctData("dataBins"))
  .groupBy($"data", $"wins")
  .count()
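One thing worth checking before relying on this untested join is its direction. As written, each distinct value is matched to the bins at or below it, so the count looks more like a rank among distinct values than the number of rows still at risk. Matching each bin to every raw row at or above it gives the atRisk definition from the question. A plain Scala stand-in for both joins (no Spark; rows, binValues, rankLike and atRiskByBin are names invented for this sketch):

```scala
// Plain Scala stand-in for the non-equi join; no Spark involved.
val rows: Seq[Double] = Seq(1.0, 2.3, 4.5, 0.8, 0.7, 4.0, 0.8) // the "data" column
val binValues: Seq[Double] = rows.distinct.sorted               // plays distinctData

// As written above: each distinct value joined to the bins at or below it.
// The count is the rank of the value among the distinct values.
val rankLike: Map[Double, Int] =
  binValues.map(d => d -> binValues.count(b => d >= b)).toMap

// Other direction: each bin joined to every raw row at or above it.
// The count is the number of observations still at risk.
val atRiskByBin: Map[Double, Int] =
  binValues.map(b => b -> rows.count(d => d >= b)).toMap

binValues.foreach(v => println(s"$v rankLike=${rankLike(v)} atRisk=${atRiskByBin(v)}"))
```

In Spark terms the second form would correspond to joining the bins against the raw df rather than against the aggregated distDF, with the wins column joined back on afterwards.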
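Answer 1 (score: 1)
Collecting is a must when your requirement is to check a value in a column against all the other values in that column. When every value needs to be checked, all the data of that column has to be accumulated on one executor or on the driver. You cannot avoid that step given your requirement.
Now the main part is how to define the remaining steps so that they benefit from Spark's parallelization. I would suggest you broadcast the collected set (it is only the distinct data of one column, so it should not be huge) and use a udf function to check the gte condition, as below.
First, you can optimize your collect step as
import org.apache.spark.sql.functions._
val collectedData = df.select(sort_array(collect_set("data"))).collect()(0)(0).asInstanceOf[collection.mutable.WrappedArray[Double]]
Then you broadcast the collected set
val broadcastedArray = sc.broadcast(collectedData)
The next step is to define a udf function that checks the gte condition and returns the counts
def checkingUdf = udf((data: Double)=> broadcastedArray.value.count(x => x >= data))
and use it as
distDF.withColumn("atRisk", checkingUdf(col("data"))).show(false)
So finally you should have
+----+----------+----+------+
|data|censorFlag|wins|atRisk|
+----+----------+----+------+
|4.5 |1 |1 |1 |
|0.7 |0 |0 |6 |
|2.3 |0 |0 |3 |
|1.0 |1 |1 |4 |
|0.8 |1 |2 |5 |
|0.8 |1 |2 |5 |
|4.0 |1 |1 |2 |
+----+----------+----+------+
I hope this is what is required.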
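Note that the table reports atRisk = 6 for 0.7 although the sample has 7 rows: collect_set keeps each value once, so the duplicated 0.8 is counted a single time. If ties should count in full, one possible variation is to broadcast (value, multiplicity) pairs and sum the multiplicities inside the UDF. The plain Scala sketch below shows only that counting logic, without Spark; observed, valueCounts, atRiskFor and distinctAtOrAbove are hypothetical names.

```scala
// Plain Scala sketch of the counting logic only; no Spark involved.
val observed: Seq[Double] = Seq(1.0, 2.3, 4.5, 0.8, 0.7, 4.0, 0.8)

// What would be collected and broadcast in this variation:
// one (value, multiplicity) pair per distinct value.
val valueCounts: Seq[(Double, Int)] =
  observed.groupBy(identity).map { case (v, vs) => v -> vs.size }.toSeq.sortBy(_._1)

// Body of a hypothetical UDF: sum multiplicities of all values >= data.
def atRiskFor(data: Double): Int =
  valueCounts.collect { case (v, n) if v >= data => n }.sum

// Counting distinct values only, as the UDF above does.
def distinctAtOrAbove(data: Double): Int =
  valueCounts.count { case (v, _) => v >= data }

println(s"0.7: atRisk=${atRiskFor(0.7)} distinct=${distinctAtOrAbove(0.7)}")
```

The two functions agree whenever every value is unique and differ only where ties occur.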
Answer 2 (score: 1)
I tried out the examples above (albeit not in the most rigorous way!), and it seems the left join works best in general.
The data:
import org.apache.spark.mllib.random.RandomRDDs._
val df = logNormalRDD(sc, 1, 3.0, 10000, 100).zip(uniformRDD(sc, 10000, 100).map(x => if(x <= 0.4) 1 else 0)).toDF("data", "censorFlag").withColumn("data", round(col("data"), 2))
The join example:
def runJoin(sc: SparkContext, df: DataFrame): Unit = {
  val bins = df.select(col("data").as("dataBins")).distinct().sort("dataBins")
  val wins = df.groupBy(col("data")).agg(sum("censorFlag").as("wins"))
  val atRiskCounts = bins.join(df, bins("dataBins") <= df("data")).groupBy("dataBins").count().withColumnRenamed("count", "atRisk")
  val finalDF = wins.join(atRiskCounts, wins("data") === atRiskCounts("dataBins")).select("data", "wins", "atRisk").sort("data")
  finalDF.show
}
The broadcast example:
def runBroadcast(sc: SparkContext, df: DataFrame): Unit = {
  val bins = df.select(sort_array(collect_set("data"))).collect()(0)(0).asInstanceOf[collection.mutable.WrappedArray[Double]]
  val binsBroadcast = sc.broadcast(bins)
  val df2 = binsBroadcast.value.map(x => (x, df.filter(col("data").geq(x)).select(count(col("data"))).as[Long].first)).toDF("data", "atRisk")
  val finalDF = df.groupBy(col("data")).agg(sum("censorFlag").as("wins")).join(df2, "data")
  finalDF.show
  binsBroadcast.destroy
}
The test code:
import java.util.concurrent.TimeUnit
var start = System.nanoTime()
runJoin(sc, sampleDF)
val joinTime = TimeUnit.SECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS)
start = System.nanoTime()
runBroadcast(sc, sampleDF)
val broadTime = TimeUnit.SECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS)
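I ran this code for different sizes of the random data and supplied manual bins arrays (some very granular, 50% of the original distinct data; some very small, 10% of the original distinct data), and the join approach consistently seems to be the fastest (although both arrive at the same solution, so that is a plus!).
On average I find that the smaller the bin array, the better the broadcast approach works, but the join doesn't seem to be affected much. If I had more time/resources to test this, I would run lots of simulations to see what the average run time looks like, but for now I'll accept @hoyland's solution.
I am still not sure why the original approach didn't work, so I am open to comments on that.
Kindly let me know of any issues in my code, or improvements! Thank you both :)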
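On the NullPointerException: Spark does not support using one Dataset inside a transformation of another. The function passed to bins.map runs on the executors, while df.filter can only build a new plan through the driver's SparkSession. That is consistent with the trace in the question failing in Dataset.filter and then in the Dataset constructor. A toy model in plain Scala follows; FakeSession and FakeDataset are hypothetical stand-ins rather than Spark classes, and the null is set by hand as an assumption about what the executor-side copy looks like.

```scala
import scala.util.Try

// Hypothetical stand-ins, not Spark classes: FakeSession plays the
// driver-only SparkSession, FakeDataset plays a Dataset that needs it.
class FakeSession {
  def plan(op: String): String = s"planned($op)"
}

class FakeDataset(val session: FakeSession) {
  // Like Dataset.filter, this must go through the session to build a new plan.
  def filter(cond: String): String = session.plan(s"filter $cond")
}

// On the driver the session is present.
val driverCopy = new FakeDataset(new FakeSession)

// Assumption of this model: the copy seen inside an executor-side closure
// has lost its session, so it is set to null by hand here.
val executorCopy = new FakeDataset(null)

val onDriver: Try[String]   = Try(driverCopy.filter("data >= 0.7"))
val onExecutor: Try[String] = Try(executorCopy.filter("data >= 0.7"))

println(onDriver)             // the planned filter, wrapped in Success
println(onExecutor.isFailure) // the null session makes the call fail
```

This would also explain why the first version in the question works: there bins is collected into a local array, so the map runs on the driver, where df is usable.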