This code creates an RDD of integers and prints them:
from pyspark.sql.types import StructType, StructField, IntegerType

schema = StructType([StructField('value', IntegerType(), False)])
rdd = sc.parallelize([[100], [50], [25]])
myrdd = sqlContext.createDataFrame(rdd, schema).rdd
for x in myrdd.collect():
    print(x)
This prints:
Row(value=100)
Row(value=50)
Row(value=25)
I am trying to subtract a value from this RDD so that each subtraction carries any remainder over to the next row. For example, subtracting 125 takes 100 from the first row and 25 from the second, leaving a new RDD of:
Row(value=0)
Row(value=25)
Row(value=25)
As another example, subtracting 160 takes 100 from the first row, 50 from the second, and 10 from the third, leaving a new RDD of:
Row(value=0)
Row(value=0)
Row(value=15)
My attempt:
valueToRemove = 125

def myFun(s):
    valueToRemove = valueToRemove - s['value']
    return Row(value=valueToRemove)

myrdd1 = myrdd.map(myFun)
for x in myrdd1.collect():
    print(x)
This fails with the error:
UnboundLocalError: local variable 'valueToRemove' referenced before assignment
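The error comes from Python's scoping rules, not from Spark: assigning to `valueToRemove` anywhere inside `myFun` makes the name local to the whole function body, so the read on the right-hand side happens before any local value exists. A minimal reproduction outside Spark:

```python
counter = 10

def decrement():
    # The assignment below makes `counter` local to decrement(),
    # so reading it on the right-hand side raises UnboundLocalError.
    counter = counter - 1
    return counter

try:
    decrement()
except UnboundLocalError as e:
    print(type(e).__name__)  # UnboundLocalError
```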
I think the natural solution would be foldLeft, but Apache Spark does not support foldLeft. I also cannot use fold, because the rows must be processed in a deterministic order.
How can I subtract a value from each row and carry the result of the subtraction into the next row?
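For reference, the intended carry-over behavior is easy to state on a plain Python list (a driver-side sketch of the semantics, not a distributed solution; `subtract_sequential` is a hypothetical helper name):

```python
def subtract_sequential(values, amount):
    """Subtract `amount` from the values in order, carrying the
    remainder of each subtraction into the next row."""
    out = []
    for v in values:
        take = min(v, amount)  # take as much as this row can supply
        amount -= take
        out.append(v - take)
    return out

print(subtract_sequential([100, 50, 25], 125))  # [0, 25, 25]
print(subtract_sequential([100, 50, 25], 160))  # [0, 0, 15]
```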
Update:
Adding global:
schema = StructType([StructField('value', IntegerType(), False)])
rdd = sc.parallelize([[100], [50], [25]])
myrdd = sqlContext.createDataFrame(rdd, schema).rdd
for x in myrdd.collect():
    print(x)

global valueToRemove
valueToRemove = 125

def myFun(s):
    valueToRemove = valueToRemove - s['value']
    return Row(value=valueToRemove)

myrdd1 = myrdd.map(myFun)
for x in myrdd1.collect():
    print(x)
This fails with the same error.
Answer 0 (score: 1)
I solved it under the assumption that each row carries an explicit index column defining the processing order. Based on that assumption, here is my version of the input:
schema = StructType([
    StructField('row', IntegerType(), False),
    StructField('value', IntegerType(), False)
])
rdd = sc.parallelize([[1, 100], [2, 50], [3, 25], [4, 225]])
myrdd = sqlContext.createDataFrame(rdd, schema)
for x in myrdd.collect():
    print(x)
This prints:
Row(row=1, value=100)
Row(row=2, value=50)
Row(row=3, value=25)
Row(row=4, value=225)
First, add a cumulative sum column:
from pyspark.sql.window import Window
import pyspark.sql.functions as F

w = Window.orderBy("row")
tempDF = myrdd.select("value", "row", F.sum("value").over(w).alias("cumsum"))
tempDF.show()
This prints:
+-----+---+------+
|value|row|cumsum|
+-----+---+------+
| 100| 1| 100|
| 50| 2| 150|
| 25| 3| 175|
| 225| 4| 400|
+-----+---+------+
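The window sum is the distributed analogue of a running total; on the driver the same column can be sketched with `itertools.accumulate`. (Note that `Window.orderBy` without a `partitionBy` pulls all rows into a single partition, which is acceptable here since the data is small.)

```python
from itertools import accumulate

values = [100, 50, 25, 225]
cumsum = list(accumulate(values))  # running total over the ordered rows
print(cumsum)  # [100, 150, 175, 400]
```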
Finally, I defined a UDF to compute the new value:
def new_val(cumsum_val, row_val, target_val):
    if cumsum_val - row_val >= target_val:
        # rows that come after the "affected area"
        return row_val
    if cumsum_val - target_val < 0:
        # rows whose entire value is consumed
        return 0
    # the row left with a remainder
    return cumsum_val - target_val

new_val_udf = F.udf(new_val)

value = 160
tempDF.withColumn("new_val", new_val_udf(F.col("cumsum"), F.col("value"), F.lit(value))).show()
The output is:
+-----+---+------+-------+
|value|row|cumsum|new_val|
+-----+---+------+-------+
| 100| 1| 100| 0|
| 50| 2| 150| 0|
| 25| 3| 175| 15|
| 225| 4| 400| 225|
+-----+---+------+-------+
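Because `new_val` only looks at a row's value and its cumulative sum, the whole computation can be verified locally with plain Python (a sketch mirroring the UDF on the same inputs):

```python
from itertools import accumulate

def new_val(cumsum_val, row_val, target_val):
    if cumsum_val - row_val >= target_val:
        return row_val              # rows after the affected area
    if cumsum_val - target_val < 0:
        return 0                    # rows fully consumed
    return cumsum_val - target_val  # the row left with a remainder

values = [100, 50, 25, 225]
target = 160
result = [new_val(c, v, target) for v, c in zip(values, accumulate(values))]
print(result)  # [0, 0, 15, 225]
```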