I have a trained SavedModel. I'm trying to broadcast the loaded model in Spark, but I get this error from pyspark/broadcast.py:
raise pickle.PicklingError(msg)
_pickle.PicklingError: Could not serialize broadcast: TypeError: can't pickle _thread.RLock objects
The code that loads and broadcasts the model:
import tensorflow as tf
from pyspark.sql import SparkSession
spark = (
SparkSession
.builder
.getOrCreate()
)
model = tf.keras.models.load_model(saved_model_path, compile=False)
spark.sparkContext.broadcast(model) #<--- this is where it fails
To verify, I tried to pickle the model directly, and as expected that fails with the same error:
import pickle
with open("model.pkl", 'wb') as f:
pickle.dump(model, f)
It looks to me like a SavedModel simply cannot be pickled. Similar code works fine with an h5 model. So if a SavedModel can't be pickled but needs to be broadcast, what are my options?
Thanks.
Answer 0 (score: 0)
You can use SparkFiles to ship the model files to every worker node, then load the model inside a pandas udf:
import os
import pandas as pd
import tensorflow as tf
from pyspark import SparkFiles
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType

spark = (
    SparkSession
    .builder
    .getOrCreate()
)

# A SavedModel is a directory, so it must be shipped recursively
spark.sparkContext.addFile(saved_model_path, recursive=True)

@pandas_udf(return_type, PandasUDFType.GROUPED_MAP)
def predict(data):
    # Resolve the worker-local copy of the model directory
    model_file_local = SparkFiles.get(os.path.basename(saved_model_path))
    model = tf.keras.models.load_model(model_file_local, compile=False)
    # A GROUPED_MAP udf must return a pandas DataFrame
    return pd.DataFrame(model.predict(data))
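If you would still rather broadcast something, a commonly used workaround (not part of the answer above) is to broadcast only `model.get_weights()`, which is a list of plain numpy arrays, and rebuild the architecture on each worker before calling `set_weights`. Unlike the model object, numpy arrays pickle cleanly. A minimal sketch of just the picklability, with hypothetical weight shapes:

```python
import pickle
import numpy as np

# Stand-ins for model.get_weights(): plain numpy arrays, which pickle
# cleanly, unlike the model object that holds thread locks
weights = [np.zeros((3, 2), dtype=np.float32), np.ones(2, dtype=np.float32)]

blob = pickle.dumps(weights)   # succeeds where pickling the model fails
restored = pickle.loads(blob)

print(all(np.array_equal(a, b) for a, b in zip(weights, restored)))  # True
```

The broadcast variable then carries only the weights; each executor builds the same Keras architecture locally and applies them.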