将pyspark数据帧重塑为Keras / Theano的4维numpy数组

时间:2016-08-26 19:50:59

标签: python numpy apache-spark pyspark keras

我正在尝试将一个火花数据帧,traindf变成一个4-d numpy数组。我试过这个:

traindf = sqlContext.createDataFrame([
    (1, 1, 2, 3),
    (1, 2, 2, 3),
    (1, 3, 2, 3),
    (1, 4, 2, 3),
    (2, 4, 5, 6),
    (2, 4, 5, 6),
    (3, 7, 8, 9),
    (2, 4, 5, 6),
    (3, 7, 8, 9),
    (3, 7, 8, 9)
], ("id", "image", "s", "t"))

values = (traindf.rdd.map(lambda l: [map(lambda r: float(r), l)]).collect())
x = np.array(values)
x = np.array_split(x, x.shape[0]/2)
x = np.asarray(x)
x.shape

这产生(5,2,1,4),但看起来需要keras(5,1,2,4)。我尝试了几种方法,但是没有找到一种获得正确格式的好方法。

有什么建议吗?

1 个答案:

答案 0 :(得分:0)

刚想出来,把它放到最后

x = np.reshape(x, (5, 1, 2, 4))