I am trying to implement the MinHash algorithm with Spark DataFrames in PySpark, but I have not been able to get the code below working without using collect() and broadcast variables.
df.show()
+-----+--------+--------------------+--------------------+
|index|filename| docShingles| Signatures|
+-----+--------+--------------------+--------------------+
| 0|amem.txt|[[698253355, 2257...|[21296291, 108712...|
| 1|amwh.txt|[[698253355, 1308...|[70755, 5583, 214...|
| 2|army.txt|[[696890156, 6669...|[91782, 151784, 1...|
| 3|aunt.txt|[[1783506231, 167...|[105504, 29647, 3...|
| 4|bart.txt|[[3642870898, 364...|[3465174, 1205637...|
| 5|test.txt|[[698253355, 1308...|[70755, 5583, 214...|
+-----+--------+--------------------+--------------------+
What I want is the following: one row per unordered pair of the six files, 6·5/2 = 15 pairs in total.
estJSim.show()
+-----+----------+--------+--------+--------------------+--------------------+
|Index|MinHashVal| filenm1| filenm2| lst1| lst2|
+-----+----------+--------+--------+--------------------+--------------------+
| 0| 0.0|amem.txt|amwh.txt|[21296291, 108712...|[70755, 5583, 214...|
| 1| 0.0|amem.txt|army.txt|[21296291, 108712...|[91782, 151784, 1...|
| 2| 0.0|amem.txt|aunt.txt|[21296291, 108712...|[105504, 29647, 3...|
| 3| 0.0|amem.txt|bart.txt|[21296291, 108712...|[3465174, 1205637...|
| 4| 0.0|amem.txt|test.txt|[21296291, 108712...|[70755, 5583, 214...|
| 5| 0.0|amwh.txt|army.txt|[70755, 5583, 214...|[91782, 151784, 1...|
| 6| 0.0|amwh.txt|aunt.txt|[70755, 5583, 214...|[105504, 29647, 3...|
| 7| 0.0|amwh.txt|bart.txt|[70755, 5583, 214...|[3465174, 1205637...|
| 8| 0.0|amwh.txt|test.txt|[70755, 5583, 214...|[70755, 5583, 214...|
| 9| 0.0|army.txt|aunt.txt|[91782, 151784, 1...|[105504, 29647, 3...|
| 10| 0.0|army.txt|bart.txt|[91782, 151784, 1...|[3465174, 1205637...|
| 11| 0.0|army.txt|test.txt|[91782, 151784, 1...|[70755, 5583, 214...|
| 12| 0.0|aunt.txt|bart.txt|[105504, 29647, 3...|[3465174, 1205637...|
| 13| 0.0|aunt.txt|test.txt|[105504, 29647, 3...|[70755, 5583, 214...|
| 14| 0.0|bart.txt|test.txt|[3465174, 1205637...|[70755, 5583, 214...|
+-----+----------+--------+--------+--------------------+--------------------+
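(MinHashVal is still a 0.0 placeholder in this output; the value I eventually want is the usual MinHash estimate, i.e. the fraction of positions where the two signatures agree. I assume that, once the pair rows exist, it would look roughly like the sketch below; sig_similarity is a name I made up:)

from pyspark.sql.functions import udf
from pyspark.sql.types import DoubleType

# Fraction of positions where the two MinHash signatures agree;
# this estimates the Jaccard similarity of the two documents.
@udf(returnType=DoubleType())
def sig_similarity(lst1, lst2):
    matches = sum(1 for a, b in zip(lst1, lst2) if a == b)
    return float(matches) / len(lst1)

estJSim = estJSim.withColumn("MinHashVal", sig_similarity("lst1", "lst2"))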
Below is the code I am using to build this.
from pyspark.sql.functions import udf
from pyspark.sql.types import (StructType, StructField, IntegerType,
                               DoubleType, StringType, ArrayType)

# Build the two index lists: (ilist[k], jlist[k]) is the k-th pair (i, j) with i < j.
i = 0
j = 0
ilist = []
jlist = []
while i < numDocs - 2:
    if j == numDocs - 1:
        i = i + 1
        j = i + 1
    else:
        j = j + 1
    ilist.append(i)
    jlist.append(j)
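(As an aside, I realise the loop above just enumerates every index pair with i < j; something like itertools.combinations would build the same two lists in the same order:)

from itertools import combinations

# Equivalent to the while loop above: all pairs (i, j) with i < j.
pairs = list(combinations(range(numDocs), 2))
ilist = [i for i, j in pairs]
jlist = [j for i, j in pairs]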
# Broadcast the pair indices so the UDFs below can index into them.
bilist = sc.broadcast(ilist)
bjlist = sc.broadcast(jlist)

# Skeleton DataFrame: one row per pair, MinHashVal initialised to 0.0.
estJSimArr = [(x, 0.0) for x in range(numElems)]
estJSimSchema = StructType([StructField("Index", IntegerType()),
                            StructField("MinHashVal", DoubleType())])
estJSim = sql_context.createDataFrame(data=estJSimArr, schema=estJSimSchema)

# Pull every filename to the driver, broadcast, then look them up per row.
filenms = df.select(df.filename).rdd.collect()
bfilenms = sc.broadcast(filenms)
udf_filenm1 = udf(lambda row: bfilenms.value[bilist.value[row]].filename, StringType())
estJSim = estJSim.withColumn("filenm1", udf_filenm1(estJSim["Index"]))
udf_filenm2 = udf(lambda row: bfilenms.value[bjlist.value[row]].filename, StringType())
estJSim = estJSim.withColumn("filenm2", udf_filenm2(estJSim["Index"]))

# Same collect-and-broadcast pattern for the signature arrays.
signs = df.select(df.Signatures).rdd.collect()
bsigns = sc.broadcast(signs)
udf_col1 = udf(lambda row: bsigns.value[bilist.value[row]].Signatures, ArrayType(IntegerType()))
estJSim = estJSim.withColumn("lst1", udf_col1(estJSim["Index"]))
udf_col2 = udf(lambda row: bsigns.value[bjlist.value[row]].Signatures, ArrayType(IntegerType()))
estJSim = estJSim.withColumn("lst2", udf_col2(estJSim["Index"]))
But the problem remains: I have not been able to achieve this without collect() and broadcast. Can this be done without breaking parallelism? Thanks in advance :)