在工作者上加载本地(不可序列化)对象

时间:2017-08-17 20:11:31

标签: python google-cloud-dataflow apache-beam

我正在尝试将Dataflow与Tensorflow结合使用进行预测。这些预测发生在工人身上,我正在通过startup_bundle()加载模型。像这里:

class PredictDoFn(beam.DoFn): 
    def start_bundle(self):
        self.model = load_model_from_file()
    def process(self, element):
        ...

我目前的问题是,即使我处理了1000个元素,startup_bundle()函数也被多次调用(至少10次),而不是按照我希望的方式执行。这会显着减慢管道速度,因为模型需要多次加载,每次需要30秒。

有没有办法在初始化工作时加载模型,而不是每次都在start_bundle()

提前致谢! 迪米特里

1 个答案:

答案 0 :(得分:1)

最简单的方法是添加if self.model is None: self.model = load_model_from_file(),这可能不会减少模型重新加载的次数。

这是因为DoFn实例目前不在包中重复使用。这意味着在执行每个工作项后,您的模型将被忘记

您还可以在保留模型的位置创建global变量。这会减少重新加载的数量,但它实际上是非正统的(虽然它可以解决你的用例)。

全局变量方法应该是这样的:

class MyModelDoFn(object):
  def process(self, elem):
    global my_model
    if my_model is None:
      my_model = load_model_from_file()
    yield my_model.apply_to(elem)

依赖于线程局部变量的方法看起来如此。考虑到这将为每个线程加载一次模型,因此模型加载的次数取决于运行器实现(它将在Dataflow中工作):

class MyModelDoFn(object):
  _thread_local = threading.local()
  @property
  def model(self):
    model = getattr(MyModelDoFn._thread_local, 'model', None)
    if not model:
      MyModelDoFn._thread_local.model = load_model_from_file()

    return MyModelDoFn._thread_local.model

  def process(self, elem):
    yield self.model.apply_to(elem)

我猜您也可以从start_bundle来电加载模型。

注意:这种做法非常不正统,不保证可以在较新版本中使用,也不保证在所有版本中使用。