无法从pyspark RDD的地图方法

时间:2017-09-12 13:47:11

标签: python pyspark rdd

在我的应用程序的代码库中集成pyspark时,我无法在RDD的map方法中引用类的方法。我用一个简单的例子重复了这个问题,如下所示

这是一个虚拟类,我已经定义了它只是为从RDD派生的RDD的每个元素添加一个数字,这是一个类属性:

class Test:

    def __init__(self):
        self.sc = SparkContext()
        a = [('a', 1), ('b', 2), ('c', 3)]
        self.a_r = self.sc.parallelize(a)

    def add(self, a, b):
        return a + b

    def test_func(self, b):
        c_r = self.a_r.map(lambda l: (l[0], l[1] * 2))
        v = c_r.map(lambda l: self.add(l[1], b))
        v_c = v.collect()
        return v_c

test_func()调用RDD map()上的v方法,然后在add()的每个元素上调用v方法。调用test_func()会引发以下错误:

pickle.PicklingError: Could not serialize object: Exception: It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063.

现在,当我将add()方法移出课堂时,如:

def add(self, a, b):
    return a + b

class Test:

    def __init__(self):
        self.sc = SparkContext()
        a = [('a', 1), ('b', 2), ('c', 3)]
        self.a_r = self.sc.parallelize(a)

    def test_func(self, b):

        c_r = self.a_r.map(lambda l: (l[0], l[1] * 2))
        v = c_r.map(lambda l: add(l[1], b))
        v_c = v.collect()

        return v_c

现在正在调用test_func()

[7, 9, 11]

为什么会发生这种情况?如何将类方法传递给RDD的map()方法?

1 个答案:

答案 0 :(得分:6)

这是因为当pyspark尝试序列化您的函数(将其发送给工作者)时,它还需要序列化您的Test类的实例(因为您传递给map的函数在self)中引用了这个实例。此实例引用了spark上下文。您需要确保SparkContextRDD未被序列化并发送给工作人员的任何对象引用。 SparkContext只需要在司机中生活。

这应该有效:

在档案testspark.py中:

class Test(object):
    def add(self, a, b):
        return a + b

    def test_func(self, a_r, b):
        c_r = a_r.map(lambda l: (l[0], l[1] * 2))
        # now `self` has no reference to the SparkContext()
        v = c_r.map(lambda l: self.add(l[1], b)) 
        v_c = v.collect()
        return v_c

在您的主脚本中:

from pyspark import SparkContext
from testspark import Test

sc = SparkContext()
a = [('a', 1), ('b', 2), ('c', 3)]
a_r = sc.parallelize(a)

test = Test()
test.test_func(a_r, 5) # should give [7, 9, 11]