I am writing a GroupIntoBatches transform in Python, modeled after its Java counterpart:
from apache_beam import coders
from apache_beam.transforms import DoFn, ParDo, PTransform
from apache_beam.transforms.combiners import CountCombineFn
from apache_beam.transforms.timeutil import TimeDomain
from apache_beam.transforms.userstate import (
    BagStateSpec, CombiningValueStateSpec, TimerSpec, on_timer)


class GroupIntoBatches(PTransform):
  def __init__(self, batch_size):
    self.batch_size = batch_size

  @staticmethod
  def of_size(batch_size):
    return GroupIntoBatches(batch_size)

  def expand(self, pcoll):
    input_coder = coders.registry.get_coder(pcoll)
    if not input_coder.is_kv_coder():
      raise ValueError(
          'coder specified in the input PCollection is not a KvCoder')
    return pcoll | ParDo(
        _pardo_group_into_batches(self.batch_size, input_coder))


def _pardo_group_into_batches(batch_size, input_coder):
  ELEMENT_STATE = BagStateSpec('values', input_coder)
  COUNT_STATE = CombiningValueStateSpec('count', input_coder, CountCombineFn())
  EXPIRY_TIMER = TimerSpec('expiry', TimeDomain.WATERMARK)

  class _GroupIntoBatchesDoFn(DoFn):
    def process(self,
                element,
                window=DoFn.WindowParam,
                element_state=DoFn.StateParam(ELEMENT_STATE),
                count_state=DoFn.StateParam(COUNT_STATE),
                expiry_timer=DoFn.TimerParam(EXPIRY_TIMER)):
      # Allowed lateness not supported in Python SDK
      # https://beam.apache.org/documentation/programming-guide/#watermarks-and-late-data
      expiry_timer.set(window.max_timestamp())

      # Buffer the element and bump the per-key count.
      element_state.add(element)
      count_state.add(1)
      count = count_state.read()

      # Emit and reset the buffer once batch_size elements have accumulated.
      if count >= batch_size:
        batch = [element for element in element_state.read()]
        yield batch
        element_state.clear()
        count_state.clear()

    @on_timer(EXPIRY_TIMER)
    def expiry(self,
               element_state=DoFn.StateParam(ELEMENT_STATE),
               count_state=DoFn.StateParam(COUNT_STATE)):
      # Flush whatever is still buffered when the window expires.
      batch = [element for element in element_state.read()]
      yield batch
      element_state.clear()
      count_state.clear()

  return _GroupIntoBatchesDoFn()
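For context, this is roughly how I exercise the transform. The keyed sample data, the Tuple type hint, and the batch size of 3 are just placeholders I use for testing, not part of the real job:

import typing

import apache_beam as beam

# Minimal sketch: a keyed PCollection (with an explicit Tuple type hint so a
# KvCoder can be inferred) is batched in groups of 3 and printed. The data
# and batch size here are placeholders.
with beam.Pipeline() as p:
  _ = (
      p
      | 'Create' >> beam.Create(
          [('key', i) for i in range(10)]).with_output_types(
              typing.Tuple[str, int])
      | 'Batch' >> GroupIntoBatches.of_size(3)
      | 'Print' >> beam.Map(print))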
The problem is that partway through running this code I get the error

KeyError: '\x1a;\n-GroupIntoBatches/ParDo(_GroupIntoBatchesDoFn)\x12\x05count"\x03key'

After some debugging I realized it is caused by count_state.clear(), the last line of the expiry function: at that point the runner deletes the state object.
I am not very familiar with Beam yet, so if I am missing something here or doing something wrong, I am looking for someone who can point me in the right direction.
Here is the Java counterpart: https://github.com/apache/beam/blob/11a977b8b26eff2274d706541127c19dc93131a2/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java