Broadcasting a trained TensorFlow SavedModel in Spark

Date: 2020-09-17 00:03:51

Tags: python apache-spark tensorflow

I have a trained SavedModel. I am trying to broadcast the loaded model in Spark, but I get this error from pyspark/broadcast.py:

raise pickle.PicklingError(msg)
_pickle.PicklingError: Could not serialize broadcast: TypeError: can't pickle _thread.RLock objects

The code I use to load and broadcast the model:

import tensorflow as tf
from pyspark.sql import SparkSession

spark = (
    SparkSession
    .builder
    .getOrCreate()
    )

model = tf.keras.models.load_model(saved_model_path, compile=False)
spark.sparkContext.broadcast(model) #<--- this is where it fails

To confirm, I tried pickling the model directly, and as expected it also raised an error:

import pickle
with open("model.pkl", 'wb') as f:
    pickle.dump(model, f)

It looks like a SavedModel cannot be pickled; the same code works fine with an h5 model. So if a SavedModel cannot be pickled but needs to be broadcast, what are my options?
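One possible workaround, sketched here under the assumption that the SavedModel was exported from a Keras model (so `to_json()` is available), is to broadcast only the picklable pieces and rebuild the model on each executor:

```python
import tensorflow as tf
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
model = tf.keras.models.load_model(saved_model_path, compile=False)

# to_json() yields a plain string and get_weights() plain NumPy arrays;
# both pickle fine, unlike the live model object with its lock handles.
config_bc = spark.sparkContext.broadcast(model.to_json())
weights_bc = spark.sparkContext.broadcast(model.get_weights())

def rebuild_model():
    # Called on an executor: reconstruct the architecture, restore weights.
    m = tf.keras.models.model_from_json(config_bc.value)
    m.set_weights(weights_bc.value)
    return m
```

The rebuild cost is paid once per executor process, after which the model can be cached in a module-level variable.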

Thanks.

1 answer:

Answer 0 (score: 0)

You can use SparkFiles to ship the model files to all worker nodes and then load the model inside a pandas UDF:

import pandas as pd
from pyspark import SparkFiles
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType

spark = (
    SparkSession
    .builder
    .getOrCreate()
    )

# A SavedModel is a directory, so it must be shipped recursively
spark.sparkContext.addFile(saved_model_path, recursive=True)


# return_type is the output schema of the UDF (a StructType)
@pandas_udf(return_type, PandasUDFType.GROUPED_MAP)
def predict(data):
    import tensorflow as tf

    # filename is the basename of saved_model_path
    model_file_local = SparkFiles.get(filename)
    model = tf.keras.models.load_model(model_file_local, compile=False)
    return pd.DataFrame(model.predict(data))
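For completeness, a sketch of applying the grouped-map UDF above; the DataFrame `df` and its `group_id` column are assumptions here, and the feature columns of `df` must match the model's input:

```python
# group_id is any column that partitions the data into manageable batches;
# each group is handed to predict() as a pandas DataFrame on an executor.
predictions = df.groupby("group_id").apply(predict)
predictions.show()
```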