我正在运行dask,希望保存中间值。例如,当我运行时:
from distributed import Client
client = Client(IP_PORT_TO_SCHEDULER)
from dask import delayed
@delayed(pure=True)
def myfunction(a):
print("recomputing")
return a + 3
res = myfunction(1)
res2 = res**2
res3 = client.persist(res2)
resagain = res**3
resagain2 = client.persist(resagain)
我希望“重新计算”只打印一次。但是,在这种情况下,它打印两次。我想这可能是因为客户端没有缓存这个中间值。例如,运行client.has_what()
,我看到了:
{'tcp://xx.xx.xx.xx:xxxx': ['pow-9d66a68ce8be79ff9cca17a2dc58aa0b',
'pow-440784f1abedb14511aa0d633935b55a']}
我看到了幂函数的最终结果,但没有看到中间计算。有没有办法强制客户端存储这个中间计算?谢谢!
答案 0 :(得分:2)
Dask将保留您明确保留的所有结果。任何中间结果都将被清除,以节省内存。
因此,在您的情况下,您可能希望执行以下操作:
res = myfunction(1)
res = res.persist() # Ask Dask to keep this in memory
...