我有一个3D数字数据文件,我从块中读取(因为以块的形式读取比单个索引更快)。例如,在'文件中有一个MxNx30阵列,我会像这样创建一个RDD:
def read(ind):
f = customFileOpener(file)
return f['data'][:,:,ind[0]:ind[-1]+1]
indices = [[0,9],[10,19],[20,29]]
rdd = sc.parallelize(indices,3).map(lambda v:read(v))
rdd.count()
因此,3个分区中的每个分区都有一个大小为MxNx10的numpy.ndarray元素。
现在,我想在每个分区中拆分这些元素,我有10个元素,每个元素都是一个MxN数组。我尝试使用flatMap()来实现此目的,但是得到的错误是“NoneType对象不可迭代”':
def splitArr(arr):
Nmid = arr.shape[-1]
out = []
for i in range(0,Nmid):
out.append(arr[...,i])
return out
rdd2 = rdd.flatMap(lambda v: splitArr(v))
rdd2.count()
这样做的正确方法是什么?关键点是(a)我需要从文件中读取块中的数据和(b)拆分数据,使元素的大小为MxN(最好保持分区结构)。
答案 0 :(得分:2)
据我了解你的描述,这样的事情可以解决问题:
rdd.flatMap(lambda arr: (x for x in np.rollaxis(arr, 2)))
或者如果您更喜欢单独的功能:
def splitArr(arr):
for x in np.rollaxis(arr, 2):
yield x
rdd.flatMap(splitArr)