我想写一段在计算网格（集群）上执行以下运算的代码：h = M · x，其中 M 是某个矩阵，x 是向量。当然，我需要独立生成数十亿个随机向量 x，收集输出 h 并将其用于后续分析。我想用 PySpark 来实现。到目前为止，我已经运行了下面的代码，但担心它远谈不上是高效的实现。我不确定如何把广播的矩阵 M 传入函数 calcMultPath 中，我认为这部分写得不对。
import findspark
findspark.init()
import pyspark
from pyspark.sql import SparkSession
from pyspark.context import SparkContext
from pyspark.sql.functions import udf, struct
from pyspark.sql.types import IntegerType
from pyspark.sql.types import StringType
from pyspark.sql.types import FloatType
from pyspark.sql.types import ArrayType
from pyspark.sql.functions import udf, col
from pyspark.sql import SQLContext
import os
import numpy as np
from numpy import genfromtxt
# Create the SparkSession first so its builder options (e.g. the app name)
# actually apply. Constructing SparkContext() up front makes getOrCreate()
# reuse that context and ignore the builder's appName.
spark = SparkSession.builder.appName("myMCFunc").getOrCreate()
sc = spark.sparkContext
# Legacy SQLContext handle, kept for code that still uses that API; it wraps
# the same SparkContext as `spark`.
sqlContext = SQLContext(sc)

label_list = ["762876", "3098083", "0", "29000", "Info"]
NumScenarios = 100
Seed = 23729837

# Generate the (Seed, Path) pairs on the cluster with spark.range instead of
# building a Python list on the driver. The driver never materializes the
# scenarios, so NumScenarios can grow to billions. Both columns are bigint,
# matching the schema createDataFrame inferred from Python ints before.
a = spark.range(NumScenarios).selectExpr(
    f"CAST({int(Seed)} AS BIGINT) AS Seed",
    "id AS Path",
)
print("------------------------ a1 --------------------------")
a.show()
def loadMatrix():
    """Return a new 5x5 matrix of i.i.d. standard-normal entries.

    Samples come from NumPy's global legacy random state, so each call yields
    a different matrix unless ``np.random.seed`` was set beforehand.
    """
    dims = (5, 5)
    return np.random.standard_normal(size=dims)
def calcMultPath(feature_list, label=label_list, matrix=None):
    """Compute one Monte Carlo sample ``h = M @ x`` for a single (Seed, Path) row.

    Parameters
    ----------
    feature_list : Row-like
        Object exposing integer attributes ``Seed`` and ``Path``. The UDF
        passes ``struct(["Seed", "Path"])``, which arrives as a Spark ``Row``.
    label : list
        Unused; kept only so the signature stays backward compatible.
    matrix : array-like of shape (m, n), optional
        The matrix ``M``. When omitted, the driver-side broadcast variable
        ``broadCastedM`` is read, so ``M`` is shipped to each executor once
        instead of being regenerated for every row.

    Returns
    -------
    list
        ``[h]``, where ``h`` is ``M @ x`` as a list of Python floats. The outer
        single-element list is kept so callers can select ``udfResult[0]``.
    """
    if matrix is None:
        # Resolved at call time on the executor; Spark caches broadcast
        # values per executor, so repeated reads are cheap.
        matrix = broadCastedM.value
    m = np.asarray(matrix, dtype=float)
    # x must match M's column count (identical to shape[0] for square M).
    n = m.shape[1]

    # An independent random stream per (Seed, Path) pair. This costs O(n) per
    # row, instead of burning n * Path draws to "advance" one shared stream,
    # which made the whole job quadratic in the number of paths. It also
    # leaves NumPy's global random state untouched.
    rng = np.random.default_rng([int(feature_list.Seed), int(feature_list.Path)])
    x = rng.standard_normal(n)
    h = m @ x
    return [h.tolist()]
# Build M once on the driver and broadcast it, so every executor gets a single
# read-only copy instead of each row generating its own matrix.
M = loadMatrix()
broadCastedM = sc.broadcast(M)

# calcMultPath returns [h] with h a list of floats, so the column type is an
# array of arrays of doubles. Declaring array<string> stringified h instead.
udfCalcMultPath = udf(calcMultPath, "array<array<double>>")

b = (
    a.withColumn("udfResult", udfCalcMultPath(struct(["Seed", "Path"])))
    .withColumn("Output", col("udfResult")[0])
    .drop("udfResult")
)
print("----------------------------- b-------------------------------")
b.show()
# NOTE: collect() pulls every row to the driver. That is fine for
# NumScenarios=100, but for billions of paths write `b` out
# (e.g. b.write.parquet(...)) or aggregate on the cluster instead.
print(b.collect())
spark.stop()