我试图将udf函数应用于由字符串组成的数据框列。函数使用Tensorflow GUSE并将字符串转换为浮点数数组。
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import tf_sentencepiece
# Graph set up.
g = tf.Graph()
with g.as_default():
text_input = tf.placeholder(dtype=tf.string, shape=[None])
embed = hub.Module("https://tfhub.dev/google/universal-sentence-encoder-multilingual-large/1")
embedded_text = embed(text_input)
init_op = tf.group([tf.global_variables_initializer(), tf.tables_initializer()])
g.finalize()
# Initialize session.
session = tf.Session(graph=g)
session.run(init_op)
def embed_mail(x):
embedding = session.run(embedded_text, feed_dict={text_input:[x]})
embedding = flatten(embedding)
result = [np.float32(i).item() for i in embedding]
return result
但是每当我尝试使用以下命令运行此功能时:
embed_mail_udf = udf(embed_mail, ArrayType(FloatType()))
df = df.withColumn('embedding',embed_mail_udf(df.text))
我不断收到错误消息:无法序列化对象:TypeError:无法腌制SwigPyObject对象。我在做什么错了?
答案 0 :(得分:1)
要在集群Spark上运行UDF的代码,需要能够序列化“附加”到该函数的所有数据。您的UDF embed_mail
包含对TF Session的引用,因此该函数是closure
,Spark首先需要序列化tf.Session的内容。我敢打赌这是问题的原因。不幸的是,我没有使用TF的经验,但是看来您可以在运行Spark之前从TF获取邮件数据,进行广播,然后在udf中使用?