pyspark中的矩阵乘法

时间:2017-02-24 05:48:14

标签: python apache-spark

我试图在pyspark中实现矩阵乘法,但似乎multiply()函数有问题,有人可以帮我吗? 谢谢!

    from pyspark import SparkContext, SparkConf
    import sys
    #define multiplication method
    def multiply(x):
        """Build the partial products for one group of matrix entries.

        ``x`` is a ``(key, entries)`` pair. Each entry is a 3-tuple
        ``(tag, index, value)`` where ``tag`` is ``'A'`` or ``'B'``.
        Every 'A' entry is paired with every 'B' entry of the same group.

        Returns a list of ``((a_index, b_index), product)`` pairs; the list
        is empty when the group lacks entries from either matrix.
        """
        # Materialise once so the group can be scanned for both tags.
        entries = list(x[1])
        a_entries = [e for e in entries if e[0] == 'A']
        b_entries = [e for e in entries if e[0] == 'B']
        # Values may still be strings (e.g. from str.split), and str * str
        # raises TypeError, so convert to float before multiplying.
        # The original called res.union(...), which does not exist on list.
        return [
            ((a[1], b[1]), float(a[2]) * float(b[2]))
            for a in a_entries
            for b in b_entries
        ]

    # Driver: read two comma-separated matrix files, join their entries on the
    # shared index, multiply pairwise and sum the partial products per cell.
    # NOTE(review): lines are presumably "row,col,value" triples -- confirm
    # against the actual input files.
    conf = SparkConf().setMaster("local[*]").setAppName("MatrixMultiplication")
    sc = SparkContext(conf=conf)

    matA = sc.textFile(sys.argv[1])
    # The method is textFile (capital F); sc.textfile raises AttributeError.
    matB = sc.textFile(sys.argv[2])

    # Key A by its second field and B by its first field so that entries
    # sharing that index end up in the same group.
    matA = matA.map(lambda x: x.split(','))
    matA = matA.map(lambda x: (x[1], ('A', x[0], x[2])))

    matB = matB.map(lambda x: x.split(','))
    matB = matB.map(lambda x: (x[0], ('B', x[1], x[2])))

    mat = matA.union(matB)
    mat = mat.groupByKey()

    # multiply emits ((i, k), product) pairs; summing by key gives each cell.
    matC = mat.flatMap(multiply).reduceByKey(lambda x, y: x + y)
    output = matC.collect()

    # Use a context manager so the output file is flushed and closed.
    with open(sys.argv[3], 'w') as f:
        for x in output:
            f.write(str(x[0][0]) + ',' + str(x[0][1]) + '\t' + str(x[1]) + '\n')

0 个答案:

没有答案