从任务

时间:2015-07-28 18:54:01

标签: python scala apache-spark pyspark apache-spark-mllib

背景

我原来的问题是为什么在map函数中使用DecisionTreeModel.predict会引发异常?并与How to generate tuples of (original lable, predicted label) on Spark with MLlib?

相关

当我们使用Scala API a recommended way使用RDD[LabeledPoint]获取DecisionTreeModel的预测时,只需映射RDD

val labelAndPreds = testData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}

不幸的是,PySpark中的类似方法效果不好:

labelsAndPredictions = testData.map(
    lambda lp: (lp.label, model.predict(lp.features))
labelsAndPredictions.first()
  

异常:您似乎正在尝试从广播变量,操作或转换引用SparkContext。 SparkContext只能在驱动程序上使用,而不能在工作程序上运行的代码中使用。有关详细信息,请参阅SPARK-5063

而不是official documentation建议这样的事情:

predictions = model.predict(testData.map(lambda x: x.features))
labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)

那么这里发生了什么?这里没有广播变量,Scala API定义predict如下:

/**
 * Predict values for a single data point using the model trained.
 *
 * @param features array representing a single data point
 * @return Double prediction from the trained model
 */
def predict(features: Vector): Double = {
  topNode.predict(features)
}

/**
 * Predict values for the given data set using the model trained.
 *
 * @param features RDD representing data points to be predicted
 * @return RDD of predictions for each of the given data points
 */
def predict(features: RDD[Vector]): RDD[Double] = {
  features.map(x => predict(x))
}

所以至少乍一看,从行动或转变中调用并不是一个问题,因为预测似乎是一个本地操作。

解释

经过一番挖掘后,我发现问题的根源是JavaModelWrapper.call调用的DecisionTreeModel.predict方法。调用Java函数需要access SparkContext

callJavaFunc(self._sc, getattr(self._java_model, name), *a)

问题

DecisionTreeModel.predict的情况下,有一个推荐的解决方法,并且所有必需的代码已经是Scala API的一部分,但是有没有优雅的方法来处理这样的问题?

我现在只能想到的解决方案是相当重量级的:

  • 通过Implicit Conversions扩展Spark类或添加某种包装器将所有内容推送到JVM
  • 直接使用Py4j网关

1 个答案:

答案 0 :(得分:42)

使用默认Py4J网关的通信根本不可能。要理解为什么我们必须从PySpark Internals文档[1]中查看下图:

enter image description here

由于Py4J网关在驱动程序上运行,因此Python解释器无法通过套接字与JVM工作者进行通信(例如参见PythonRDD / rdd.py)。

理论上可以为每个工作者创建一个单独的Py4J网关,但实际上它不太可能有用。忽略可靠性等问题Py4J根本不是为执行数据密集型任务而设计的。

有任何变通方法吗?

  1. 使用Spark SQL Data Sources API包装JVM代码。

    优点:支持,高级别,不需要访问内部PySpark API

    缺点:相对冗长且记录不充分,主要限于输入数据

  2. 使用Scala UDF在DataFrame上运行。

    优点:易于实施(请参阅Spark: How to map Python with Scala or Java User Defined Functions?),如果数据已存储在DataFrame中,则Python与Scala之间无数据转换,对Py4J的访问权限最小

    缺点:需要访问Py4J网关和内部方法,仅限于Spark SQL,难以调试,不支持

  3. 以与MLlib相同的方式创建高级Scala接口。

    优点:灵活,能够执行任意复杂的代码。它可以直接在RDD上(例如参见MLlib model wrappers)或DataFrames(参见How to use a Scala class inside Pyspark)。后一种解决方案似乎更加友好,因为所有服务细节都已由现有API处理。

    缺点:低级别,必需的数据转换,与UDF相同,需要访问Py4J和内部API,不支持

    可以在Transforming PySpark RDD with Scala

  4. 中找到一些基本示例
  5. 使用外部工作流管理工具在Python和Scala / Java作业之间切换,并将数据传递给DFS。

    优点:易于实施,对代码本身的修改最少

    缺点:读取/写入数据的成本(Alluxio?)

  6. 使用共享SQLContext(请参阅例如Apache ZeppelinLivy)使用已注册的临时表在来宾语言之间传递数据。

    优点:非常适合互动分析

    缺点:与批处理作业(Zeppelin)不同,或者可能需要额外的编排(Livy)

    1. 约书亚罗森。 (2014年8月4日)PySpark Internals。取自https://cwiki.apache.org/confluence/display/SPARK/PySpark+Internals