I'm writing a Spark job that takes in data from multiple sources, filters out bad input rows, and outputs a slightly modified version of the input. The job also has two additional requirements.
The job seemed simple, and I used accumulators to solve it, tracking the number of filtered rows per source. However, when I implemented the final limit on the output, my accumulator behavior changed. Here is some stripped-down sample code that triggers the behavior on a single source:

from pyspark.sql import Row, SparkSession
from pyspark.sql.types import *
from random import randint


def filter_and_transform_parts(rows, filter_int, accum):
    for r in rows:
        if r[0] == filter_int:
            accum.add(1)
            continue
        yield r[0], r[1] + 1, r[2] + 1


def main():
    spark = SparkSession \
        .builder \
        .appName("Test") \
        .getOrCreate()
    sc = spark.sparkContext
    accum = sc.accumulator(0)

    # 20 inputs w/ tuple having 4 as first element
    inputs = [(4, randint(1, 10), randint(1, 10)) if x % 5 == 0 else (randint(6, 10), randint(6, 10), randint(6, 10)) for x in xrange(100)]
    rdd = sc.parallelize(inputs)

    # filter out tuples where 4 is first element
    rdd = rdd.mapPartitions(lambda r: filter_and_transform_parts(r, 4, accum))

    # if not limit, accumulator value is 20
    # if limit and limit_count <= 63, accumulator value is 0
    # if limit and limit_count >= 64, accumulator value is 20
    limit = True
    limit_count = 63

    if limit:
        rdd = rdd.map(lambda r: Row(r[0], r[1], r[2]))
        df_schema = StructType([StructField("val1", IntegerType(), False),
                                StructField("val2", IntegerType(), False),
                                StructField("val3", IntegerType(), False)])
        df = spark.createDataFrame(rdd, schema=df_schema)
        df = df.limit(limit_count)
        df.write.mode("overwrite").csv('foo/')
    else:
        rdd.saveAsTextFile('foo/')

    print "Accum value: {}".format(accum.value)


if __name__ == "__main__":
    main()
The problem is that my accumulator sometimes reports the number of filtered rows and sometimes doesn't, depending on the limit specified and the number of inputs per source. However, in every case the filtered rows never make it into the output, which means the filter did run and the accumulator should have a value.
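A rough way to double-check this is to read the output back and compare it with the accumulator; a minimal sketch (assuming the same spark, accum, and 'foo/' output path as in the sample above, and Spark's default _c0 column name for a headerless CSV):

# Read the CSV written above back in and compare with the accumulator.
out = spark.read.csv('foo/')                        # columns come back as strings: _c0, _c1, _c2
rows_written = out.count()
filter_rows_in_output = out.filter(out._c0 == '4').count()
print("rows written: {}, filter rows in output: {}, accum: {}".format(
    rows_written, filter_rows_in_output, accum.value))
# filter_rows_in_output is always 0, yet accum.value is sometimes 0 as well.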
If you can shed some light on this, it would be very helpful, thanks!
Update:
Adding an rdd.persist() call after the mapPartitions stage made the accumulator behavior consistent.
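i.e., sketching against the sample code above (storage level left at the default):

rdd = rdd.mapPartitions(lambda r: filter_and_transform_parts(r, 4, accum))
rdd.persist()  # keep the filtered/transformed RDD around; after this, the accumulator reported 20 consistently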
Answer 0 (score: 1):
Actually, the value of limit_count doesn't matter.
The reason Accum value is sometimes 0 is that you perform the accumulator update inside transformations (e.g. rdd.map, rdd.mapPartitions). Spark only guarantees that accumulators work correctly inside actions (e.g. rdd.foreach).
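For illustration, a small standalone sketch of the difference (separate from your code; the partition count and data are arbitrary): an accumulator updated inside a transformation only moves when that transformation is actually evaluated, and only for the partitions that get evaluated, while an update inside an action covers every element.

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("AccumSketch").getOrCreate()
sc = spark.sparkContext
data = sc.parallelize(range(100), 4)          # 100 elements in 4 partitions

# Update inside a transformation: lazy, and possibly partial.
acc_map = sc.accumulator(0)
mapped = data.map(lambda x: (acc_map.add(1), x)[1])
print(acc_map.value)                          # 0 -- nothing evaluated yet
mapped.take(5)                                # take() may compute only the first partition
print(acc_map.value)                          # likely 25 here, not 100

# Update inside an action: evaluated over all partitions.
acc_foreach = sc.accumulator(0)
data.foreach(lambda x: acc_foreach.add(1))
print(acc_foreach.value)                      # 100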
Let's make some changes to your code:
from pyspark.sql import *
from random import randint


def filter_and_transform_parts(rows, filter_int, accum):
    # Only counts the matching rows here; the counting now happens inside an action.
    for r in rows:
        if r[0] == filter_int:
            accum.add(1)


def main():
    spark = SparkSession.builder.appName("Test").getOrCreate()
    sc = spark.sparkContext
    print(sc.applicationId)
    accum = sc.accumulator(0)

    inputs = [(4, x * 10, x * 100) if x % 5 == 0 else (randint(6, 10), x * 10, x * 100) for x in xrange(100)]
    rdd = sc.parallelize(inputs)

    # foreachPartition is an action, so the accumulator update is guaranteed to run for every partition.
    rdd.foreachPartition(lambda r: filter_and_transform_parts(r, 4, accum))

    limit = True
    limit_count = 10 or 'whatever'  # evaluates to 10; per the point above, the exact value does not matter
    if limit:
        rdd = rdd.map(lambda r: Row(val1=r[0], val2=r[1], val3=r[2]))
        df = spark.createDataFrame(rdd)
        df = df.limit(limit_count)
        df.write.mode("overwrite").csv('file:///tmp/output')
    else:
        rdd.saveAsTextFile('file:///tmp/output')

    print "Accum value: {}".format(accum.value)


if __name__ == "__main__":
    main()
With these changes, the accumulator value is always equal to 20.
For more information:
http://spark.apache.org/docs/2.0.2/programming-guide.html#accumulators
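The accumulator example in that guide is roughly the following (paraphrased, not quoted verbatim): updates made inside an action are applied exactly once per task and show up in accum.value on the driver.

accum = sc.accumulator(0)
sc.parallelize([1, 2, 3, 4]).foreach(lambda x: accum.add(x))
print(accum.value)  # 10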