如何在多个dask数组之间共享相同的索引

时间:2019-01-13 13:31:22

标签: dask

我正在尝试构建一个基于dask的ipython应用程序,该应用程序包含一个元类,该元类包含一些子dask数组(它们都是成形的(n_samples,dim_1,dim_2 ...)),应该能够通过其 getitem 运算符对子任务数组进行分区。

getitem 方法中,我调用da.Array.compute方法(代码仍处于非常早期的状态),因此我可以迭代一批子数组。

def MetaClass(object):
    ...    
    def __getitem__(self, inds):
        new_m = MetaClass()
        inds = inds.compute()
        for name,var in vars(self).items():
            if isinstance(var,da.Array):
                try:
                    setattr(new_m, name, var[inds])
                except Exception as e:
                    print(e)
            else:
                setattr(new_m, name, var)
        return new_m

# Here I construct the meta-class to work with some directory.
m = MetaClass('/my/data/...')
# m.type is one of the sub-dask-arrays
m2 = m[m.type==2]

它按预期工作,并且得到了切片的数组,但是结果我得到了巨大的内存消耗,并且我认为dask的机制是在后台复制每个子dask-array的索引。 / p>

我的问题是,如何在不使用大量内存的情况下达到相同的结果?

(我尝试不“计算” getitem 中的“ inds”,但随后我得到了一个不可变形的数组,该数组不能被迭代,这对于应用程序来说是必须的)

我一直在考虑三种可能的解决方案,很高兴得知其中哪一种对我来说是“正确”的解决方案。 (或获得我从未想到的另一种解决方案):

  1. 要使用Dask DataFrame,我不确定如何将其放入多维数组(非常感谢您提供帮助,甚至链接中介绍了如何处理dd中的多维数组)。
  2. 忘记整个MetaClass,并使用一个带有讨厌dtype的dask数组(类似[[“ type”,int,(1,)),(“ images”,np.uint8,(1000, 1000)]]),同样,我对此并不熟悉,并且非常感谢您提供一些帮助(尝试用Google进行搜索。它有点复杂。)
  3. 要与属性及其调用函数机制(https://docs.python.org/2/library/functions.html#property)一起在调用函数( getitem )中作为全局索引共享。但是这里最大的弊端是我失去了数组的类型(对于表示形式以及除了数据本身以外需要任何东西的一切而言,这都是很糟糕的。)

提前谢谢!

1 个答案:

答案 0 :(得分:0)

可以将sub-arrays.map_blocks与共享函数一起使用,该函数将索引保存在其内存中。

这里是一个例子:

        def bool_mask(arr, block_info=None):
            from_ind,to_ind = block_info[0]["array-location"][0]
            return arr[inds[from_ind:to_ind]]

        def getitem(var):
            original_chunks = var.chunks[0]
            tmp_inds = np.cumsum([0]+list(original_chunks))
            from_inds = tmp_inds[:-1]
            to_inds = tmp_inds[1:]
            new_chunks_0 = np.array(list(map(lambda f,t:inds[f:t].sum(),from_inds,to_inds)))
            new_chunks = tuple([tuple(new_chunks_0.tolist())] + list(var.chunks[1:]))
            return var.map_blocks(bool_mask,dtype=var.dtype,chunks=new_chunks)