我有一个 PySpark RDD，其中每一行包含句子 ID 和对应的向量：
[Row(review_id='R1EGUNQON2J277_id_What', vector=DenseVector([0.033, 0.1455, -0.1428])),
Row(review_id='R1FIU59Z4UXOIT_id_I wanted the gift car4 for MUSIC not movies', vector=DenseVector([-0.0121, 0.1022, -0.0883])),
Row(review_id='R359VY6VMS5CKK_id_I had no idea that the amount is listed in US dollars', vector=DenseVector([0.0597, 0.0795, 0.1087])),
Row(review_id='R359VY6VMS5CKK_id_I ended purchasing US $100 instead of CAD $100, which meant extra $30CAD', vector=DenseVector([0.1173, 0.1267, -0.0521]))]
我用自定义函数计算余弦差，并配合 Spark 开箱即用（OOTB）的 agg、min 等聚合函数，代码如下所示：
# Function for cosine similarity
def cos_sim(a, b):
    """Return the cosine similarity of vectors *a* and *b* as a Python float.

    Cosine similarity is undefined when either vector has zero length.
    In that case 0.0 is returned instead of NaN, so that averages and
    rankings computed from these values are not poisoned by NaN.
    """
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    if norm_product == 0:
        # Zero vector: treat as "no similarity" rather than dividing by zero.
        return 0.0
    # float() converts the NumPy scalar into a plain Python float.
    return float(np.dot(a, b) / norm_product)
# Plain function (used with RDD.map, not a Spark UDF) giving cosine difference
def diff(row):
    """Convert one pair of rows into ``(id_left, id_right, 1 - cosine similarity)``.

    *row* is indexed as a two-element pair, such as the tuples produced by
    ``RDD.cartesian``. On each side, element 0 is taken as the identifier
    and element 1 as the vector passed to ``cos_sim``.
    """
    left = row[0]
    right = row[1]
    similarity = cos_sim(left[1], right[1])
    return left[0], right[0], 1 - similarity
# Function to get the Center of Vectors
def get_center(pca_embeddings, num_neighbours=11):
    """Find the review with the smallest average cosine difference to all reviews.

    Every pair of rows in *pca_embeddings* is scored with ``diff``. This
    includes each row paired with itself, because ``cartesian`` is taken
    against the same RDD. The cost is therefore O(n^2) in the number of rows.

    Args:
        pca_embeddings: RDD whose rows hold an identifier at index 0 and a
            vector at index 1, e.g. ``Row(review_id=..., vector=DenseVector(...))``.
        num_neighbours: for each review A, the reviews B with
            ``rank <= num_neighbours`` (ordered by cosine difference) are
            kept. The default of 11 matches the original cut-off (rank < 12).
            A paired with itself normally has the smallest difference, so
            the list normally includes A itself. Ties in rank may keep a few
            extra entries.

    Returns:
        dict with keys ``'Mean Cosine'``, ``'ReviewA'`` and
        ``'Nearest Neighbours'``. The order of the neighbour list is not
        guaranteed, because ``collect_list`` runs after a shuffle.

    Raises:
        ValueError: if *pca_embeddings* is empty.
    """
    if pca_embeddings.isEmpty():
        # toDF cannot infer a schema from an empty RDD; fail with a clear message.
        raise ValueError("pca_embeddings is empty")

    # All (A, B) pairs -> (id_A, id_B, 1 - cosine similarity).
    # The column names must match every reference below.
    cosine_df = (
        pca_embeddings.cartesian(pca_embeddings)
        .map(diff)
        .toDF(["ReviewA", "ReviewB", "CosineDiff"])
    )

    # Average cosine difference of each review A against every review B.
    # With no orderBy, the window frame is the whole partition.
    mean_window = Window.partitionBy("ReviewA")
    mean_df = cosine_df.withColumn("mean", F.avg("CosineDiff").over(mean_window))

    # Rank reviews B by closeness to each review A.
    rank_window = Window.partitionBy("ReviewA").orderBy("CosineDiff")
    rank_df = mean_df.withColumn("rank", F.rank().over(rank_window))

    # Keep the closest reviews B for each review A. "mean" is constant within
    # a group, so max() simply carries it through the aggregation.
    final_df = (
        rank_df.filter(F.col("rank") <= num_neighbours)
        .groupBy("ReviewA")
        .agg(
            F.collect_list("ReviewB").alias("Nearest Neighbours"),
            F.max("mean").alias("Mean Cosine"),
        )
    )

    # min() over a struct compares field by field, starting with the first,
    # so putting "Mean Cosine" first selects the row with the smallest mean.
    # The alias avoids depending on Spark's auto-generated,
    # version-specific column name.
    other_cols = [c for c in final_df.columns if c != "Mean Cosine"]
    best = final_df.select(
        F.min(F.struct("Mean Cosine", *other_cols)).alias("center")
    ).first()
    return best["center"].asDict()