在Pyspark

时间:2015-10-06 05:06:39

标签: python numpy apache-spark pyspark

我有一个3D数字数据文件,我从块中读取(因为以块的形式读取比单个索引更快)。例如,在'文件中有一个MxNx30阵列,我会像这样创建一个RDD:

def read(ind):
    f = customFileOpener(file)
    return f['data'][:,:,ind[0]:ind[-1]+1]

indices = [[0,9],[10,19],[20,29]]
rdd = sc.parallelize(indices,3).map(lambda v:read(v))
rdd.count()

因此,3个分区中的每个分区都有一个大小为MxNx10的numpy.ndarray元素。

现在,我想在每个分区中拆分这些元素,我有10个元素,每个元素都是一个MxN数组。我尝试使用flatMap()来实现此目的,但是得到的错误是“NoneType对象不可迭代”':

def splitArr(arr):
    Nmid = arr.shape[-1]
    out = []
    for i in range(0,Nmid):
         out.append(arr[...,i])
    return out

rdd2 = rdd.flatMap(lambda v: splitArr(v))
rdd2.count()

这样做的正确方法是什么?关键点是(a)我需要从文件中读取块中的数据和(b)拆分数据,使元素的大小为MxN(最好保持分区结构)。

1 个答案:

答案 0 :(得分:2)

据我了解你的描述,这样的事情可以解决问题:

rdd.flatMap(lambda arr: (x for x in np.rollaxis(arr, 2)))

或者如果您更喜欢单独的功能:

def splitArr(arr):
    for x in np.rollaxis(arr, 2):
        yield x

rdd.flatMap(splitArr)