Problem implementing the MinHash algorithm with PySpark DataFrames

Asked: 2019-04-13 21:06:11

Tags: python pyspark minhash

I am trying to implement the MinHash algorithm using Spark DataFrames in PySpark, but I have not been able to write the code below without resorting to collect() and broadcast.

df.show()
+-----+--------+--------------------+--------------------+
|index|filename|         docShingles|          Signatures|
+-----+--------+--------------------+--------------------+
|    0|amem.txt|[[698253355, 2257...|[21296291, 108712...|
|    1|amwh.txt|[[698253355, 1308...|[70755, 5583, 214...|
|    2|army.txt|[[696890156, 6669...|[91782, 151784, 1...|
|    3|aunt.txt|[[1783506231, 167...|[105504, 29647, 3...|
|    4|bart.txt|[[3642870898, 364...|[3465174, 1205637...|
|    5|test.txt|[[698253355, 1308...|[70755, 5583, 214...|
+-----+--------+--------------------+--------------------+

This is what I want:

estJSim.show()
+-----+----------+--------+--------+--------------------+--------------------+
|Index|MinHashVal| filenm1| filenm2|                lst1|                lst2|
+-----+----------+--------+--------+--------------------+--------------------+
|    0|       0.0|amem.txt|amwh.txt|[21296291, 108712...|[70755, 5583, 214...|
|    1|       0.0|amem.txt|army.txt|[21296291, 108712...|[91782, 151784, 1...|
|    2|       0.0|amem.txt|aunt.txt|[21296291, 108712...|[105504, 29647, 3...|
|    3|       0.0|amem.txt|bart.txt|[21296291, 108712...|[3465174, 1205637...|
|    4|       0.0|amem.txt|test.txt|[21296291, 108712...|[70755, 5583, 214...|
|    5|       0.0|amwh.txt|army.txt|[70755, 5583, 214...|[91782, 151784, 1...|
|    6|       0.0|amwh.txt|aunt.txt|[70755, 5583, 214...|[105504, 29647, 3...|
|    7|       0.0|amwh.txt|bart.txt|[70755, 5583, 214...|[3465174, 1205637...|
|    8|       0.0|amwh.txt|test.txt|[70755, 5583, 214...|[70755, 5583, 214...|
|    9|       0.0|army.txt|aunt.txt|[91782, 151784, 1...|[105504, 29647, 3...|
|   10|       0.0|army.txt|bart.txt|[91782, 151784, 1...|[3465174, 1205637...|
|   11|       0.0|army.txt|test.txt|[91782, 151784, 1...|[70755, 5583, 214...|
|   12|       0.0|aunt.txt|bart.txt|[105504, 29647, 3...|[3465174, 1205637...|
|   13|       0.0|aunt.txt|test.txt|[105504, 29647, 3...|[70755, 5583, 214...|
|   14|       0.0|bart.txt|test.txt|[3465174, 1205637...|[70755, 5583, 214...|
+-----+----------+--------+--------+--------------------+--------------------+
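
(MinHashVal is just a 0.0 placeholder at this stage; the plan is to fill it with the estimated Jaccard similarity, i.e. the fraction of signature positions on which the two documents agree. A minimal sketch of that estimate, assuming equal-length signature lists:)

# Estimated Jaccard similarity from two MinHash signatures:
# the fraction of hash positions where the signatures match.
def est_jaccard(sig1, sig2):
    matches = sum(1 for a, b in zip(sig1, sig2) if a == b)
    return matches / len(sig1)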

Below is the code I am using to build this:

from pyspark.sql.functions import udf
from pyspark.sql.types import (StructType, StructField, IntegerType,
                               DoubleType, StringType, ArrayType)

# Enumerate every unordered document pair (i, j) with i < j.
i = 0
j = 0
ilist = []
jlist = []
while i < numDocs - 2:
    if j == numDocs - 1:
        i = i + 1
        j = i + 1
    else:
        j = j + 1
    ilist.append(i)
    jlist.append(j)
bilist = sc.broadcast(ilist)
bjlist = sc.broadcast(jlist)
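
(For what it's worth, the loop above just enumerates all numDocs*(numDocs-1)/2 index pairs with i < j; itertools.combinations would build the same lists:)

import itertools

# Equivalent pair enumeration: every (i, j) with i < j, in the same order.
pairs = list(itertools.combinations(range(numDocs), 2))
ilist = [i for i, j in pairs]
jlist = [j for i, j in pairs]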


# One row per document pair, with a placeholder similarity of 0.0.
numElems = len(ilist)
estJSimArr = [(x, 0.0) for x in range(numElems)]
estJSimSchema = StructType([StructField("Index", IntegerType()), StructField("MinHashVal", DoubleType())])
estJSim = sql_context.createDataFrame(data=estJSimArr, schema=estJSimSchema)

# Map each pair index back to its two filenames via collected, broadcast rows.
filenms = df.select(df.filename).rdd.collect()
bfilenms = sc.broadcast(filenms)
udf_filenm1 = udf(lambda row: bfilenms.value[bilist.value[row]].filename, StringType())
estJSim = estJSim.withColumn("filenm1", udf_filenm1(estJSim["Index"]))
udf_filenm2 = udf(lambda row: bfilenms.value[bjlist.value[row]].filename, StringType())
estJSim = estJSim.withColumn("filenm2", udf_filenm2(estJSim["Index"]))

# Same pattern again to attach the two signature lists.
signs = df.select(df.Signatures).rdd.collect()
bsigns = sc.broadcast(signs)

udf_col1 = udf(lambda row: bsigns.value[bilist.value[row]].Signatures, ArrayType(IntegerType()))
estJSim = estJSim.withColumn("lst1", udf_col1(estJSim["Index"]))

udf_col2 = udf(lambda row: bsigns.value[bjlist.value[row]].Signatures, ArrayType(IntegerType()))
estJSim = estJSim.withColumn("lst2", udf_col2(estJSim["Index"]))

The problem is that I cannot achieve this without using collect() and broadcast, which pulls everything onto the driver. Can this be done without breaking parallelism? Thanks in advance :)
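
For example, would a self-join along these lines be the right direction? (An untested sketch; it reuses the index, filename, and Signatures columns of df above and keeps only the i < j pairs via the join condition.)

from pyspark.sql import functions as F

# Pair every document with every later one by self-joining on index < index.
# Everything stays distributed; nothing is collected to the driver.
left = df.select(F.col("index").alias("i"),
                 F.col("filename").alias("filenm1"),
                 F.col("Signatures").alias("lst1"))
right = df.select(F.col("index").alias("j"),
                  F.col("filename").alias("filenm2"),
                  F.col("Signatures").alias("lst2"))
pairs = left.join(right, F.col("i") < F.col("j"))

# The estimate itself could then be a column-level UDF over the two lists.
match_frac = udf(lambda a, b: float(sum(x == y for x, y in zip(a, b))) / len(a),
                 DoubleType())
estJSim = pairs.withColumn("MinHashVal", match_frac("lst1", "lst2"))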

0 Answers:

No answers yet.