如何系统地重用 Dask 中延迟函数的结果?

时间:2021-04-14 08:50:28

标签: python dask dask-delayed

我正在使用 Dask 构建计算图。一些中间值将被多次使用,但我希望这些计算只运行一次。我一定是犯了一个小错误,因为事实并非如此。这是一个最小的例子:

In [1]:    import dask
           dask.__version__
    
Out [1]:   '1.0.0'

In [2]:   class SumGenerator(object):
              def __init__(self):
                  self.sources = []
    
              def register(self, source):
                  self.sources += [source]
        
              def generate(self):
                  return dask.delayed(sum)([s() for s in self.sources])

In [3]:    sg = SumGenerator()

In [4]:    @dask.delayed
           def source1():
               return 1.

           @dask.delayed
           def source2():
               return 2.

           @dask.delayed
           def source3():
               return 3.

In [5]:    sg.register(source1)
           sg.register(source1)
           sg.register(source2)
           sg.register(source3)

In [6]:    sg.generate().visualize()

遗憾的是,我无法发布生成的图形图像,但基本上我看到函数 source1 的两个单独节点已注册两次。因此该函数被调用两次。我宁愿让它调用一次,结果记住并在总和中添加两次。这样做的正确方法是什么?

1 个答案:

答案 0 :(得分:3)

您需要通过传递 dask.delayed 参数来调用 pure=True 装饰器。

来自dask delayed docs

<块引用>

delayed 也接受一个可选的关键字 pure。如果为 False,则后续调用将始终产生不同的 Delayed

如果你知道一个函数是纯函数(输出只取决于输入,没有全局状态),那么你可以设置 pure=True。

所以使用它

import dask

class SumGenerator(object):
    def __init__(self):
        self.sources = []

    def register(self, source):
        self.sources += [source]

    def generate(self):
        return dask.delayed(sum)([s() for s in self.sources])

@dask.delayed(pure=True)
def source1():
    return 1.

@dask.delayed(pure=True)
def source2():
    return 2.

@dask.delayed(pure=True)
def source3():
    return 3.

sg = SumGenerator()

sg.register(source1)
sg.register(source1)
sg.register(source2)
sg.register(source3)

sg.generate().visualize()

输出和图表

Graph

使用 print(dask.compute(sg.generate())) 得到 (7.0,),它与您编写的相同,但没有图像中看到的额外节点。

相关问题