获取PySpark中可见节点的数量

时间:2015-02-27 15:30:16

标签: python-2.7 apache-spark pyspark

我在PySpark中运行一些操作,最近增加了配置中的节点数(在Amazon EMR上)。然而,即使我将节点数量增加了两倍(从4到12),性能似乎也没有改变。因此,我想看看Spark是否可以看到新节点。

我正在调用以下函数:

sc.defaultParallelism
>>>> 2

但我认为这告诉我分配给每个节点的任务总数,而不是Spark可以看到的代码总数。

如何查看PySpark在群集中使用的节点数量?

5 个答案:

答案 0 :(得分:24)

在pyspark上,你仍然可以使用pyspark的py4j网桥调用scala getExecutorMemoryStatus API:

sc._jsc.sc().getExecutorMemoryStatus().size()

答案 1 :(得分:16)

sc.defaultParallelism只是一个提示。根据配置,它可能与节点数量无关。如果您使用带有分区计数参数但未提供分区计数参数的操作,则这是分区数。例如,sc.parallelize将从列表中生成新的RDD。您可以使用第二个参数告诉它在RDD中创建多少个分区。但此参数的默认值为sc.defaultParallelism

您可以在Scala API中获取sc.getExecutorMemoryStatus的执行者数量,但这不会在Python API中公开。

一般来说,建议RDD中的分区大约是执行程序的4倍。这是一个很好的提示,因为如果任务所花费的时间存在差异,这将使其均匀。例如,一些执行程序将处理5个更快的任务,而其他执行程序处理3个较慢的任务。

你不需要对此非常准确。如果你有一个粗略的想法,你可以去估计。就像你知道你的CPU少于200个,你可以说500个分区就可以了。

因此尝试使用此数量的分区创建RDD:

rdd = sc.parallelize(data, 500)     # If distributing local data.
rdd = sc.textFile('file.csv', 500)  # If loading data from a file.

如果您不控制RDD的创建,则在计算之前重新分配RDD:

rdd = rdd.repartition(500)

您可以使用rdd.getNumPartitions()检查RDD中的分区数。

答案 2 :(得分:3)

应该可以使用此方法获取群集中的节点数量(类似于上面的@ Dan&#39方法,但更短,效果更好!)。

sc._jsc.sc().getExecutorMemoryStatus().keySet().size()

答案 3 :(得分:1)

我发现有时我的会话被远程程序杀死,给出了一个奇怪的Java错误

Py4JJavaError: An error occurred while calling o349.defaultMinPartitions.
: java.lang.IllegalStateException: Cannot call methods on a stopped SparkContext.

我通过以下

避免了这种情况
def check_alive(spark_conn):
    """Check if connection is alive. ``True`` if alive, ``False`` if not"""
    try:
        get_java_obj = spark_conn._jsc.sc().getExecutorMemoryStatus()
        return True
    except Exception:
        return False

def get_number_of_executors(spark_conn):
    if not check_alive(spark_conn):
        raise Exception('Unexpected Error: Spark Session has been killed')
    try:
        return spark_conn._jsc.sc().getExecutorMemoryStatus().size()
    except:
        raise Exception('Unknown error')

答案 4 :(得分:1)

其他答案提供了获取执行者数量的方法。这是一种获取节点数的方法。这包括头部和工作节点。

s = sc._jsc.sc().getExecutorMemoryStatus().keys()
l = str(s).replace("Set(","").replace(")","").split(", ")

d = set()
for i in l:
    d.add(i.split(":")[0])
len(d)