DoFn构造了多少次?

时间:2019-09-09 16:22:51

标签: python apache-beam dataflow

我正在使用apache beam python SDK和Dataflow编写推理管道,以使用TensorFlow模型进行预测。我在DoFn中有预测步骤,但是我不想每次处理包时都必须加载模型,因为这非常昂贵。在文档here中,“如果需要,可以在工作线程上创建参数DoFn的新实例,并在该实例上调用DoFn.Setup方法。这可以通过反序列化或其他方式进行。PipelineRunner可以将DoFn实例重用于多个包。异常终止(通过引发Exception)的DoFn将永远不会被重用。”我注意到,如果我这样编写代码

class StatefulGetEmbeddingsDoFn(beam.DoFn):
    def __init__(self, model_dir):
         self.model = None # initialize
         self.model_dir = model_dir

    def process(self, element):
         if not self.model: # load model if model hasn't been loaded yet
             global i
             i += 1
             logging.info('Getting model: {}'.format(i))
             self.model = Model(saved_model_dir=self.model_dir)


         ids, b64 = element
         embeddings = self.model.predict(b64)

         res = [
            {
                'image': _id,
                'embeddings': embedding.tolist()
            } for _id, embedding in zip(ids, embeddings)
         ]
         return res

似乎每个工作人员都多次加载模型(我有大约30-40台计算机集群)。有没有一种方法可以防止模型被多次加载?我本来希望该DoFn在每台机器上只能构建一次,但是从日志中构建,似乎并非如此……

1 个答案:

答案 0 :(得分:0)

我知道这是一个比较老的问题,但是我最初的想法是使用setupstart_bundle方法。

https://beam.apache.org/releases/pydoc/2.22.0/apache_beam.transforms.core.html#apache_beam.transforms.core.DoFn.setup