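I have a list of all the distinct values of a column (taken from the training set), and I need to replace every value in the test set that is not in this list with 1.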
I have tried this:
uniq = X_train3.select('street').distinct().collect()
X_test3 = X_test3.withColumn('street', F.when(array_contains('street', uniq), 1))
I also tried this:
uniq = X_train3.select('street').distinct().collect()
X_test3 = X_test3.withColumn('street', F.when(~col('street').isin(uniq), 1))
Both result in this error: java.lang.RuntimeException: Unsupported literal type class java.util.ArrayList [[1.0]]
This is what works for me in plain Python (pandas):
uniq = X_train3[cl].unique()
uniq = uniq.tolist()
X_test3['street'] = X_test3['street'].map(lambda x: 1 if x not in uniq else x)
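For reference, the same pandas logic can be written without a per-element lambda by using `Series.isin` and `Series.where`. A minimal sketch, assuming `X_train3` and `X_test3` are pandas DataFrames; the helper name `replace_unseen_pd` is hypothetical:

```python
import pandas as pd


def replace_unseen_pd(train: pd.DataFrame, test: pd.DataFrame, column: str = "street", fill=1) -> pd.DataFrame:
    """Return a copy of `test` where values of `column` absent from `train` become `fill`."""
    out = test.copy()
    # isin() marks the rows whose value was seen in training;
    # where() keeps those and replaces all other rows with `fill`.
    out[column] = out[column].where(out[column].isin(train[column].unique()), fill)
    return out
```

Usage: `X_test3 = replace_unseen_pd(X_train3, X_test3, "street")`.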
Answer 0 (score: -1)
You can do this with a left outer join against the distinct training values (shown in Scala; the PySpark equivalent is analogous):
import org.apache.spark.sql.functions._

val new_X_test3 = X_test3
  .join(X_train3
      .select("street")
      .distinct()
      .withColumnRenamed("street", "street_train"),
    col("street") === col("street_train"),
    "leftouter")
  // No match in X_train3 -> street_train is null -> replace with 1
  .withColumn("street_test",
    when(col("street_train").isNull, lit(1))
      .otherwise(col("street")))
  .drop("street", "street_train")
  .withColumnRenamed("street_test", "street")
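If you are sure the list of unique streets is very small (which you seem to assume, since your code collects it to the driver), you can wrap X_train3 in a broadcast hint. The code then becomes: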
import org.apache.spark.sql.functions._

val new_X_test3 = X_test3
  .join(broadcast(X_train3
      .select("street")
      .distinct()
      .withColumnRenamed("street", "street_train")),
    col("street") === col("street_train"),
    "leftouter")
  // No match in X_train3 -> street_train is null -> replace with 1
  .withColumn("street_test",
    when(col("street_train").isNull, lit(1))
      .otherwise(col("street")))
  .drop("street", "street_train")
  .withColumnRenamed("street_test", "street")