如何在分布式dask数组中找到最小的n个值

时间:2019-06-28 21:24:12

标签: dask dask-distributed

我有一个形状为(2400,2400),块大小为(100,100)的分布式dask数组。我以为可以使用topk(-n)来找到最小的n个值。但是,它似乎返回一个形状为(2400,n)的数组,因此看起来好像在每一行中找到了最小的n。是否有一种方法可以使用topk获取所有行中的最小n值(整个数组)?

一个想法是两次调用topk,每个轴一次。

>>> dist
dask.array<pow, shape=(2400, 2400), dtype=float64, chunksize=(100, 100)>
>>> dist.topk(-5,axis=0).topk(-5,axis=1).compute()
array([[   0.        , 2620.09503644, 2842.15200157, 2955.08409356,
        3163.49458669],
       [3660.67698657, 3670.4457495 , 3700.09837707, 3717.09052889,
        4002.86497399],
       [4125.89820524, 4139.44658137, 4250.50420539, 4331.01304547,
        4402.14606754],
       [4328.22966119, 4378.25193428, 4507.94409903, 4522.4913488 ,
        4555.06860541],
       [4441.58755402, 4560.95625938, 4576.39333974, 4682.06215251,
        4765.11531865]])

1 个答案:

答案 0 :(得分:1)

  

一个想法是两次调用topk,每个轴一次。

对我来说听起来不错!

您可能会考虑先对数组进行展平,但是对于您已经发现的内容,我看不出有什么好处。

x.flatten().topk(...)