从GCS读取numy数组到spark

时间:2016-07-12 18:14:27

标签: python google-cloud-storage pyspark google-cloud-dataproc

我在谷歌存储中有100个包含numpy数组的npz文件。 我已经使用jupyter设置了dataproc,我正在尝试将所有numpy数组读入spark RDD。将numpy数组从谷歌存储加载到pyspark的最佳方法是什么? 是否有一种像np.load("gs://path/to/array.npz")这样的简单方法来加载numpy数组,然后对其进行sc.parallelize

1 个答案:

答案 0 :(得分:2)

如果您计划最终扩展,则需要在SparkContext中使用分布式输入方法,而不是依赖sc.parallelize从驱动程序加载任何本地文件。听起来你需要完整地阅读每个文件,所以在你的情况下你想要:

npz_rdd = sc.binaryFiles("gs://path/to/directory/containing/files/")

或者你也可以根据需要指定单个文件,但是你只需要一个带有单个元素的RDD:

npz_rdd = sc.binaryFiles("gs://path/to/directory/containing/files/arr1.npz")

然后每条记录都是一对<filename>,<str of bytes>。在Dataproc上,sc.binaryFiles将自动直接使用GCS路径,而不像需要本地文件系统路径的np.load

然后在您的工作代码中,您只需使用StringIO将这些字节字符串用作放入np.load的文件对象:

from StringIO import StringIO
# For example, to create an RDD of the 'arr_0' element of each of the picked objects:
npz_rdd.map(lambda l: numpy.load(StringIO(l[1]))['arr_0'])

在开发过程中,如果您真的只想将文件读入主驱动程序,可以使用collect()将RDD关闭以在本地检索它:

npz_rdd = sc.binaryFiles("gs://path/to/directory/containing/files/arr1.npz")
local_bytes = npz_rdd.collect()[0][1]
local_np_obj = np.load(StringIO(local_bytes))