Using a custom Python object in a PySpark UDF

Asked: 2017-10-11 15:38:46

Tags: python apache-spark pyspark apache-spark-sql

When I run the following PySpark code:

from itertools import chain
from pyspark.sql.functions import UserDefinedFunction
from pyspark.sql.types import ArrayType, StringType

nlp = NLPFunctions()  # custom class wrapping an expensive NLP model

def parse_ingredients(ingredient_lines):
    parsed_ingredients = nlp.getingredients_bulk(ingredient_lines)[0]
    return list(chain.from_iterable(parsed_ingredients))


udf_parse_ingredients = UserDefinedFunction(parse_ingredients, ArrayType(StringType()))

I get the following error: _pickle.PicklingError: Could not serialize object: TypeError: can't pickle _thread.lock objects

I assume this is because PySpark cannot serialize this custom class. But how can I avoid the overhead of instantiating this expensive object on every invocation of the parse_ingredients function?
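The error can be reproduced without Spark at all: the default pickler fails on any object whose state holds a thread lock, which NLP libraries commonly create internally. A minimal, hypothetical stand-in for such a class:

```python
import pickle
import threading

class NLPLike:
    """Hypothetical stand-in for an expensive object (like NLPFunctions)
    that holds an unpicklable resource such as a thread lock."""
    def __init__(self):
        self.lock = threading.Lock()

# Pickling the instance fails because its __dict__ contains a lock
try:
    pickle.dumps(NLPLike())
except TypeError as e:
    print("TypeError:", e)
```

This is exactly what happens when PySpark tries to ship the closure (which captures nlp) to the executors.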

3 Answers:

Answer 0: (score: 1)

Suppose you want to use an Identity class defined like this (identity.py):

class Identity(object):                   
    def __getstate__(self):
        raise NotImplementedError("Not serializable")

    def identity(self, x):
        return x

You can, for example, use a callable object (f.py) and store the Identity instance as a class attribute:

from identity import Identity

class F(object):                          
    identity = None

    def __call__(self, x):
        if not F.identity:
            F.identity = Identity()
        return F.identity.identity(x)

and use it like this:

from pyspark.sql.functions import udf
import f

sc.addPyFile("identity.py")
sc.addPyFile("f.py")

f_ = udf(f.F())

spark.range(3).select(f_("id")).show()
+-----+
|F(id)|
+-----+
|    0|
|    1|
|    2|
+-----+

Or use a standalone function and a closure:

from pyspark.sql.functions import udf
import identity

sc.addPyFile("identity.py")

def f(): 
    dict_ = {}                 
    @udf()              
    def f_(x):                 
        if "identity" not in dict_:
            dict_["identity"] = identity.Identity()
        return dict_["identity"].identity(x)
    return f_


spark.range(3).select(f()("id")).show()
+------+
|f_(id)|
+------+
|     0|
|     1|
|     2|
+------+
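Both variants rely on the same trick: the object that gets pickled and shipped to the workers (the F callable, or the closure over an empty dict) is cheap to serialize, and the expensive Identity instance is built lazily on the worker the first time the UDF runs. A Spark-free sketch of that caching behaviour, with a hypothetical instance counter added to show the object is constructed only once:

```python
class Identity:
    """Same class as in identity.py, plus a counter for illustration."""
    instances = 0

    def __init__(self):
        Identity.instances += 1

    def __getstate__(self):
        raise NotImplementedError("Not serializable")

    def identity(self, x):
        return x

class F:
    identity = None

    def __call__(self, x):
        # Instantiate lazily, once per process, on first call
        if not F.identity:
            F.identity = Identity()
        return F.identity.identity(x)

f = F()
results = [f(i) for i in range(3)]
print(results)               # [0, 1, 2]
print(Identity.instances)    # 1 -- built exactly once despite three calls
```

On a real cluster this happens once per Python worker process rather than once per row, which is the saving the question asks for.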

Answer 1: (score: 1)

I solved it by making all dependencies of the NLPFunctions class serializable, based on (https://github.com/scikit-learn/scikit-learn/issues/6975).
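One common way to make such a class serializable (a hypothetical sketch, not the asker's actual NLPFunctions) is to drop the unpicklable attributes in __getstate__ and rebuild them in __setstate__:

```python
import pickle
import threading

class NLPFunctionsLike:
    """Hypothetical sketch: exclude the unpicklable lock from the
    pickled state and recreate it on unpickling."""
    def __init__(self):
        self.lock = threading.Lock()                  # not picklable
        self.model = {"tokens": ["salt", "pepper"]}   # picklable state

    def __getstate__(self):
        state = self.__dict__.copy()
        del state["lock"]            # locks cannot be pickled
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.lock = threading.Lock()  # rebuild on the worker side

# Round-trips cleanly now
restored = pickle.loads(pickle.dumps(NLPFunctionsLike()))
print(restored.model)  # {'tokens': ['salt', 'pepper']}
```

Note this still pays the serialization cost per task, unlike the lazy-instantiation approach in Answer 0.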

Answer 2: (score: 0)

Edit: This answer is wrong. The object is still serialized and then deserialized when it is broadcast, so serialization is not avoided. (Tips for properly using large broadcast variables?)

Try using a broadcast variable:

from itertools import chain
from pyspark import SparkContext

sc = SparkContext()
nlp_broadcast = sc.broadcast(nlp)  # ship nlp to each worker once, not per task

def parse_ingredients(ingredient_lines):
    parsed_ingredients = nlp_broadcast.value.getingredients_bulk(ingredient_lines)[0]
    return list(chain.from_iterable(parsed_ingredients))