When running the following PySpark code:
from itertools import chain

from pyspark.sql.functions import UserDefinedFunction
from pyspark.sql.types import ArrayType, StringType

nlp = NLPFunctions()  # custom, non-serializable class

def parse_ingredients(ingredient_lines):
    parsed_ingredients = nlp.getingredients_bulk(ingredient_lines)[0]
    return list(chain.from_iterable(parsed_ingredients))

udf_parse_ingredients = UserDefinedFunction(parse_ingredients, ArrayType(StringType()))
I get the following error:
_pickle.PicklingError: Could not serialize object: TypeError: can't pickle _thread.lock objects
I assume this is because PySpark cannot serialize this custom class. But how do I avoid the overhead of instantiating this expensive object on every run of the parse_ingredients_line function?
Answer 0 (score: 1)
Let's say you want to use an Identity class defined like this (identity.py):
class Identity(object):
    def __getstate__(self):
        raise NotImplementedError("Not serializable")

    def identity(self, x):
        return x
You can, for example, use a callable object (f.py) and store the Identity instance as a class member:
from identity import Identity

class F(object):
    identity = None

    def __call__(self, x):
        if not F.identity:
            F.identity = Identity()
        return F.identity.identity(x)
and use it like this:
from pyspark.sql.functions import udf
import f

sc.addPyFile("identity.py")
sc.addPyFile("f.py")

f_ = udf(f.F())
spark.range(3).select(f_("id")).show()
+-----+
|F(id)|
+-----+
| 0|
| 1|
| 2|
+-----+
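This works because the F instance is pickled while F.identity is still None; the non-serializable Identity is only constructed lazily, once per Python worker process, the first time the UDF runs there.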
Or use a standalone function and a closure:
from pyspark.sql.functions import udf
import identity

sc.addPyFile("identity.py")

def f():
    dict_ = {}

    @udf()
    def f_(x):
        if "identity" not in dict_:
            dict_["identity"] = identity.Identity()
        return dict_["identity"].identity(x)
    return f_

spark.range(3).select(f()("id")).show()
spark.range(3).select(f()("id")).show()
+------+
|f_(id)|
+------+
| 0|
| 1|
| 2|
+------+
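Applying the same lazy-initialization idea to the original question could look like the sketch below. This is only a sketch: it assumes NLPFunctions lives in an importable module (the name nlp_functions is hypothetical) that has been shipped to the workers, e.g. with sc.addPyFile.

from itertools import chain

from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType

from nlp_functions import NLPFunctions  # hypothetical module name

cache = {}

@udf(ArrayType(StringType()))
def udf_parse_ingredients(ingredient_lines):
    # Construct NLPFunctions lazily, once per Python worker process,
    # instead of pickling a driver-side instance into the closure.
    if "nlp" not in cache:
        cache["nlp"] = NLPFunctions()
    parsed = cache["nlp"].getingredients_bulk(ingredient_lines)[0]
    return list(chain.from_iterable(parsed))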
Answer 1 (score: 1)
I solved it by making all dependencies of the NLPFunctions class serializable, based on https://github.com/scikit-learn/scikit-learn/issues/6975.
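For context, a minimal sketch of what that can look like: exclude the unpicklable members from the pickled state and recreate them after deserialization (the lock attribute and _load_model helper below are hypothetical, not taken from the actual NLPFunctions):

import threading

class NLPFunctions(object):
    def __init__(self):
        self.lock = threading.Lock()     # unpicklable; the source of the PicklingError
        self.model = self._load_model()  # hypothetical expensive resource

    def _load_model(self):
        return {}  # placeholder for real model loading

    def __getstate__(self):
        # Leave the lock out of the serialized state.
        state = self.__dict__.copy()
        del state["lock"]
        return state

    def __setstate__(self, state):
        # Restore the state and recreate the lock on the worker.
        self.__dict__.update(state)
        self.lock = threading.Lock()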
Answer 2 (score: 0)
Edit: this answer is wrong. The object is still serialized on the driver and then deserialized when broadcast, so serialization is not avoided (see "Tips for properly using large broadcast variables?").
Try using a broadcast variable.
sc = SparkContext()
nlp_broadcast = sc.broadcast(nlp)  # stores nlp in deserialized form on each executor

def parse_ingredients(ingredient_lines):
    parsed_ingredients = nlp_broadcast.value.getingredients_bulk(ingredient_lines)[0]
    return list(chain.from_iterable(parsed_ingredients))
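Since broadcasting still pickles nlp on the driver, a different workaround worth mentioning (not from any answer here) is to build the object once per partition with mapPartitions. A sketch, where the names nlp_functions and ingredients_rdd are hypothetical:

from itertools import chain

from nlp_functions import NLPFunctions  # hypothetical module name

def parse_partition(rows):
    # One NLPFunctions instance per partition, constructed on the worker,
    # so nothing unpicklable crosses the driver/worker boundary.
    nlp = NLPFunctions()
    for ingredient_lines in rows:
        parsed = nlp.getingredients_bulk(ingredient_lines)[0]
        yield list(chain.from_iterable(parsed))

parsed_rdd = ingredients_rdd.mapPartitions(parse_partition)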